Skip to content

Commit 41fba65

Browse files
committed
Refactor projection and expressions.
1 parent a58e5e0 commit 41fba65

File tree

3 files changed

+40
-38
lines changed

3 files changed

+40
-38
lines changed

django_mongodb/compiler.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,19 @@ def __init__(self, *args, **kwargs):
3434
def _get_group_alias_column(self, expr, annotation_group_idx):
3535
"""Generate a dummy field for use in the ids fields in $group."""
3636
replacement = None
37-
if isinstance(expr, Col):
38-
col = expr
37+
38+
# Unwrap Ref (in this part of the pipeline we can't do references over $projected fields).
39+
expr_ = expr
40+
while isinstance(expr_, Ref):
41+
expr_ = expr_.source
42+
replacement = expr_
43+
44+
if isinstance(expr_, Col):
45+
col = expr_
3946
else:
4047
# If the column is a composite expression, create a field for it.
4148
alias = f"__annotation_group{next(annotation_group_idx)}"
42-
col = self._get_column_from_expression(expr, alias)
49+
col = self._get_column_from_expression(expr_, alias)
4350
replacement = col
4451
if self.collection_name == col.alias:
4552
return col.target.column, replacement
@@ -137,14 +144,21 @@ def _get_group_id_expressions(self, order_by):
137144
for expr, (_, _, is_ref) in order_by:
138145
if not is_ref:
139146
group_expressions |= set(expr.get_group_by_cols())
140-
for expr, *_ in self.select:
141-
group_expressions |= set(expr.get_group_by_cols())
142147
having_group_by = self.having.get_group_by_cols() if self.having else ()
143148
for expr in having_group_by:
144149
group_expressions.add(expr)
150+
151+
refs_viewed = set()
152+
for expr, _, alias in self.select:
153+
group_expressions |= set(expr.get_group_by_cols())
154+
refs_viewed.add(alias)
145155
if isinstance(self.query.group_by, tuple | list):
146-
group_expressions |= set(self.query.group_by)
147-
elif self.query.group_by is None:
156+
for expr in self.query.group_by:
157+
if not isinstance(expr, Ref) or expr.refs not in refs_viewed:
158+
group_expressions.add(expr)
159+
if isinstance(expr, Ref):
160+
refs_viewed.add(expr.refs)
161+
if self.query.group_by is None:
148162
group_expressions = set()
149163
if not group_expressions:
150164
ids = None
@@ -428,32 +442,22 @@ def _get_aggregate_expressions(self, expr):
428442
stack.extend(expr.get_source_expressions())
429443

430444
def get_project_fields(self, columns=None, ordering=None):
431-
fields = {}
445+
fields = defaultdict(dict)
432446
for name, expr in columns or []:
447+
collection = expr.alias if isinstance(expr, Col) else None
433448
try:
434-
column = expr.target.column
435-
except AttributeError:
436-
# Generate the MQL for an annotation.
437-
try:
438-
fields[name] = expr.as_mql(self, self.connection)
439-
except EmptyResultSet:
440-
fields[name] = Value(False).as_mql(self, self.connection)
441-
except FullResultSet:
442-
fields[name] = Value(True).as_mql(self, self.connection)
443-
else:
444-
# If name != column, then this is an annotatation referencing
445-
# another column.
446-
fields[name] = 1 if name == column else f"${column}"
447-
if fields:
448-
# Add related fields.
449-
for alias in self.query.alias_map:
450-
if self.query.alias_refcount[alias] and self.collection_name != alias:
451-
fields[alias] = 1
449+
fields[collection][name] = expr.as_mql(self, self.connection)
450+
except EmptyResultSet:
451+
fields[collection][name] = Value(False).as_mql(self, self.connection)
452+
except FullResultSet:
453+
fields[collection][name] = Value(True).as_mql(self, self.connection)
454+
# Unwrap annotations.
455+
fields.update(fields.pop(None, {}))
456+
# Unwrap main collection's fields.
457+
fields.update(fields.pop(self.collection_name, {}))
458+
if fields and ordering:
452459
# Add order_by() fields.
453-
for alias, expression in ordering or []:
454-
nested_entity = alias.split(".", 1)[0] if "." in alias else None
455-
if alias not in fields and nested_entity not in fields:
456-
fields[alias] = expression.as_mql(self, self.connection)
460+
fields.update({alias: expr.as_mql(self, self.connection) for alias, expr in ordering})
457461
return fields
458462

459463
def _get_ordering(self):

django_mongodb/expressions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ def query(self, compiler, connection): # noqa: ARG001
7777

7878

7979
def ref(self, compiler, connection): # noqa: ARG001
80-
return f"${self.refs}"
80+
if isinstance(self.source, Col) and self.source.alias != compiler.collection_name:
81+
prefix = f"{self.source.alias}."
82+
else:
83+
prefix = ""
84+
return f"${prefix}{self.refs}"
8185

8286

8387
def star(self, compiler, connection): # noqa: ARG001

django_mongodb/features.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
5050
# Length of null considered zero rather than null.
5151
"db_functions.text.test_length.LengthTests.test_basic",
5252
# Wrong results in queries with multiple tables.
53-
"annotations.tests.NonAggregateAnnotationTestCase.test_annotation_reverse_m2m",
54-
"annotations.tests.NonAggregateAnnotationTestCase.test_annotation_with_m2m",
55-
"annotations.tests.NonAggregateAnnotationTestCase.test_chaining_annotation_filter_with_m2m",
56-
"annotations.tests.NonAggregateAnnotationTestCase.test_mti_annotations",
57-
"expressions.test_queryset_values.ValuesExpressionsTests.test_values_list_expression",
58-
"expressions.test_queryset_values.ValuesExpressionsTests.test_values_list_expression_flat",
5953
"expressions.tests.IterableLookupInnerExpressionsTests.test_expressions_in_lookups_join_choice",
6054
"queries.tests.Queries1Tests.test_order_by_tables",
61-
"queries.tests.TestTicket24605.test_ticket_24605",
6255
"queries.tests.TestInvalidValuesRelation.test_invalid_values",
6356
# QuerySet.explain() not implemented:
6457
# https://github.com/mongodb-labs/django-mongodb/issues/28
@@ -254,6 +247,7 @@ def django_test_expected_failures(self):
254247
"queries.tests.Queries6Tests.test_tickets_8921_9188",
255248
"queries.tests.Queries6Tests.test_xor_subquery",
256249
"queries.tests.QuerySetBitwiseOperationTests.test_subquery_aliases",
250+
"queries.tests.TestTicket24605.test_ticket_24605",
257251
"queries.tests.Ticket20101Tests.test_ticket_20101",
258252
"queries.tests.Ticket20788Tests.test_ticket_20788",
259253
"queries.tests.Ticket22429Tests.test_ticket_22429",

0 commit comments

Comments
 (0)