Skip to content

Commit bdd2a40

Browse files
authored
fix projection of refs in foreign collections (#99)
Handle refs from foreign collections.
1 parent 7d5d85a commit bdd2a40

File tree

3 files changed

+32
-35
lines changed

3 files changed

+32
-35
lines changed

django_mongodb/compiler.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def _build_aggregation_pipeline(self, ids, group):
193193
projected_fields[table][field] = value
194194
else:
195195
projected_fields[key] = value
196-
pipeline.append({"$addFields": projected_fields})
196+
# Convert defaultdict to dict so it doesn't appear as
197+
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
198+
pipeline.append({"$addFields": dict(projected_fields)})
197199
if "_id" not in projected_fields:
198200
pipeline.append({"$unset": "_id"})
199201
return pipeline
@@ -428,33 +430,31 @@ def _get_aggregate_expressions(self, expr):
428430
stack.extend(expr.get_source_expressions())
429431

430432
def get_project_fields(self, columns=None, ordering=None):
431-
fields = {}
433+
fields = defaultdict(dict)
432434
for name, expr in columns or []:
435+
collection = expr.alias if isinstance(expr, Col) else None
433436
try:
434-
column = expr.target.column
435-
except AttributeError:
436-
# Generate the MQL for an annotation.
437-
try:
438-
fields[name] = expr.as_mql(self, self.connection)
439-
except EmptyResultSet:
440-
fields[name] = Value(False).as_mql(self, self.connection)
441-
except FullResultSet:
442-
fields[name] = Value(True).as_mql(self, self.connection)
443-
else:
444-
# If name != column, then this is an annotatation referencing
445-
# another column.
446-
fields[name] = 1 if name == column else f"${column}"
447-
if fields:
448-
# Add related fields.
449-
for alias in self.query.alias_map:
450-
if self.query.alias_refcount[alias] and self.collection_name != alias:
451-
fields[alias] = 1
452-
# Add order_by() fields.
453-
for alias, expression in ordering or []:
454-
nested_entity = alias.split(".", 1)[0] if "." in alias else None
455-
if alias not in fields and nested_entity not in fields:
456-
fields[alias] = expression.as_mql(self, self.connection)
457-
return fields
437+
fields[collection][name] = (
438+
1
439+
# For brevity/simplicity, project {"field_name": 1}
440+
# instead of {"field_name": "$field_name"}.
441+
if isinstance(expr, Col) and name == expr.target.column
442+
else expr.as_mql(self, self.connection)
443+
)
444+
except EmptyResultSet:
445+
fields[collection][name] = Value(False).as_mql(self, self.connection)
446+
except FullResultSet:
447+
fields[collection][name] = Value(True).as_mql(self, self.connection)
448+
# Annotations (stored in None) and the main collection's fields
449+
# should appear in the top-level of the fields dict.
450+
fields.update(fields.pop(None, {}))
451+
fields.update(fields.pop(self.collection_name, {}))
452+
# Add order_by() fields.
453+
if fields and ordering:
454+
fields.update({alias: expr.as_mql(self, self.connection) for alias, expr in ordering})
455+
# Convert defaultdict to dict so it doesn't appear as
456+
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
457+
return dict(fields)
458458

459459
def _get_ordering(self):
460460
"""

django_mongodb/expressions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@ def query(self, compiler, connection): # noqa: ARG001
7777

7878

7979
def ref(self, compiler, connection): # noqa: ARG001
80-
return f"${self.refs}"
80+
prefix = (
81+
f"{self.source.alias}."
82+
if isinstance(self.source, Col) and self.source.alias != compiler.collection_name
83+
else ""
84+
)
85+
return f"${prefix}{self.refs}"
8186

8287

8388
def star(self, compiler, connection): # noqa: ARG001

django_mongodb/features.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
4949
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_timezone_applied_before_truncation",
5050
# Length of null considered zero rather than null.
5151
"db_functions.text.test_length.LengthTests.test_basic",
52-
# Wrong annotation names in $project (dollar prefixed) when querying
53-
# multiple collections.
54-
"annotations.tests.NonAggregateAnnotationTestCase.test_annotation_reverse_m2m",
55-
"annotations.tests.NonAggregateAnnotationTestCase.test_annotation_with_m2m",
56-
"annotations.tests.NonAggregateAnnotationTestCase.test_chaining_annotation_filter_with_m2m",
57-
"annotations.tests.NonAggregateAnnotationTestCase.test_mti_annotations",
58-
"expressions.test_queryset_values.ValuesExpressionsTests.test_values_list_expression",
59-
"expressions.test_queryset_values.ValuesExpressionsTests.test_values_list_expression_flat",
6052
# range lookup includes incorrect values.
6153
"expressions.tests.IterableLookupInnerExpressionsTests.test_expressions_in_lookups_join_choice",
6254
# Unexpected alias_refcount in alias_map.

0 commit comments

Comments
 (0)