Skip to content

Commit 1698963

Browse files
WaVEVtimgraham
authored andcommitted
fix collisions with projected columns and order by columns
1 parent cbee6d2 commit 1698963

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

django_mongodb/compiler.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,7 @@ def build_query(self, columns=None):
357357
if columns is None:
358358
extra_fields += ordering_fields
359359
if extra_fields:
360-
query.extra_fields = {
361-
field_name: expr.as_mql(self, self.connection) for field_name, expr in extra_fields
362-
}
360+
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
363361
where = self.get_where()
364362
try:
365363
expr = where.as_mql(self, self.connection) if where else {}
@@ -429,16 +427,18 @@ def _get_aggregate_expressions(self, expr):
429427
elif hasattr(expr, "get_source_expressions"):
430428
stack.extend(expr.get_source_expressions())
431429

432-
def get_project_fields(self, columns=None, ordering=None):
430+
def get_project_fields(self, columns=None, ordering=None, force_expression=False):
431+
if not columns:
432+
return {}
433433
fields = defaultdict(dict)
434-
for name, expr in columns or []:
434+
for name, expr in columns + (ordering or ()):
435435
collection = expr.alias if isinstance(expr, Col) else None
436436
try:
437437
fields[collection][name] = (
438438
1
439439
# For brevity/simplicity, project {"field_name": 1}
440440
# instead of {"field_name": "$field_name"}.
441-
if isinstance(expr, Col) and name == expr.target.column
441+
if isinstance(expr, Col) and name == expr.target.column and not force_expression
442442
else expr.as_mql(self, self.connection)
443443
)
444444
except EmptyResultSet:
@@ -449,9 +449,6 @@ def get_project_fields(self, columns=None, ordering=None):
449449
# should appear in the top-level of the fields dict.
450450
fields.update(fields.pop(None, {}))
451451
fields.update(fields.pop(self.collection_name, {}))
452-
# Add order_by() fields.
453-
if fields and ordering:
454-
fields.update({alias: expr.as_mql(self, self.connection) for alias, expr in ordering})
455452
# Convert defaultdict to dict so it doesn't appear as
456453
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
457454
return dict(fields)
@@ -471,7 +468,7 @@ def _get_ordering(self):
471468
for order in self.order_by_objs or []:
472469
if isinstance(order.expression, Col):
473470
field_name = order.expression.as_mql(self, self.connection).removeprefix("$")
474-
fields[field_name] = order.expression
471+
fields[order.expression.target.column] = order.expression
475472
elif isinstance(order.expression, Ref):
476473
field_name = order.expression.as_mql(self, self.connection).removeprefix("$")
477474
else:

django_mongodb/features.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
6363
"aggregation.tests.AggregateTestCase.test_reverse_fkey_annotate",
6464
"aggregation_regress.tests.AggregationTests.test_annotation_disjunction",
6565
"aggregation_regress.tests.AggregationTests.test_decimal_aggregate_annotation_filter",
66-
# Invalid $project :: caused by :: Path collision at aggregation_regress_publisher.name
67-
"aggregation_regress.tests.AggregationTests.test_values_list_annotation_args_ordering",
6866
# QuerySet.extra(select=...) should raise NotSupportedError instead of:
6967
# 'RawSQL' object has no attribute 'as_mql'.
7068
"aggregation_regress.tests.AggregationTests.test_annotate_with_extra",

0 commit comments

Comments
 (0)