From 552261ac3a93365cd538bdaa0440db9e6734cc2b Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 13 Aug 2024 01:24:24 -0300 Subject: [PATCH 1/2] fix projection of refs in foreign collections --- django_mongodb/compiler.py | 48 +++++++++++++++++------------------ django_mongodb/expressions.py | 7 ++++- django_mongodb/features.py | 8 ------ 3 files changed, 29 insertions(+), 34 deletions(-) diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 2c5cddd13..bb1ac4745 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -428,33 +428,31 @@ def _get_aggregate_expressions(self, expr): stack.extend(expr.get_source_expressions()) def get_project_fields(self, columns=None, ordering=None): - fields = {} + fields = defaultdict(dict) for name, expr in columns or []: + collection = expr.alias if isinstance(expr, Col) else None try: - column = expr.target.column - except AttributeError: - # Generate the MQL for an annotation. - try: - fields[name] = expr.as_mql(self, self.connection) - except EmptyResultSet: - fields[name] = Value(False).as_mql(self, self.connection) - except FullResultSet: - fields[name] = Value(True).as_mql(self, self.connection) - else: - # If name != column, then this is an annotatation referencing - # another column. - fields[name] = 1 if name == column else f"${column}" - if fields: - # Add related fields. - for alias in self.query.alias_map: - if self.query.alias_refcount[alias] and self.collection_name != alias: - fields[alias] = 1 - # Add order_by() fields. - for alias, expression in ordering or []: - nested_entity = alias.split(".", 1)[0] if "." in alias else None - if alias not in fields and nested_entity not in fields: - fields[alias] = expression.as_mql(self, self.connection) - return fields + fields[collection][name] = ( + 1 + # For brevity/simplicity, project {"field_name": 1} + # instead of {"field_name": "$field_name"}. + if isinstance(expr, Col) and name == expr.target.column + else expr.as_mql(self, self.connection) + ) + except EmptyResultSet: + fields[collection][name] = Value(False).as_mql(self, self.connection) + except FullResultSet: + fields[collection][name] = Value(True).as_mql(self, self.connection) + # Annotations (stored in None) and the main collection's fields + # should appear in the top-level of the fields dict. + fields.update(fields.pop(None, {})) + fields.update(fields.pop(self.collection_name, {})) + # Add order_by() fields. + if fields and ordering: + fields.update({alias: expr.as_mql(self, self.connection) for alias, expr in ordering}) + # Convert defaultdict to dict so it doesn't appear as + # "defaultdict(, ..." in query logging. + return dict(fields) def _get_ordering(self): """ diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index a9c5ab9bc..9012279ea 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -77,7 +77,12 @@ def query(self, compiler, connection): # noqa: ARG001 def ref(self, compiler, connection): # noqa: ARG001 - return f"${self.refs}" + prefix = ( + f"{self.source.alias}." + if isinstance(self.source, Col) and self.source.alias != compiler.collection_name + else "" + ) + return f"${prefix}{self.refs}" def star(self, compiler, connection): # noqa: ARG001 diff --git a/django_mongodb/features.py b/django_mongodb/features.py index a13f6d904..2fa66bf9a 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -49,14 +49,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_timezone_applied_before_truncation", # Length of null considered zero rather than null. "db_functions.text.test_length.LengthTests.test_basic", - # Wrong annotation names in $project (dollar prefixed) when querying - # multiple collections. - "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_reverse_m2m", - "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_with_m2m", - "annotations.tests.NonAggregateAnnotationTestCase.test_chaining_annotation_filter_with_m2m", - "annotations.tests.NonAggregateAnnotationTestCase.test_mti_annotations", - "expressions.test_queryset_values.ValuesExpressionsTests.test_values_list_expression", - "expressions.test_queryset_values.ValuesExpressionsTests.test_values_list_expression_flat", # range lookup includes incorrect values. "expressions.tests.IterableLookupInnerExpressionsTests.test_expressions_in_lookups_join_choice", # Unexpected alias_refcount in alias_map. From d9488a120cc8c054dabf49c07118099c830397fc Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 13 Aug 2024 21:35:18 -0300 Subject: [PATCH 2/2] prevent $addFields from appearing as defaultdict in query logging --- django_mongodb/compiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index bb1ac4745..7e49b02d9 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -193,7 +193,9 @@ def _build_aggregation_pipeline(self, ids, group): projected_fields[table][field] = value else: projected_fields[key] = value - pipeline.append({"$addFields": projected_fields}) + # Convert defaultdict to dict so it doesn't appear as + # "defaultdict(, ..." in query logging. + pipeline.append({"$addFields": dict(projected_fields)}) if "_id" not in projected_fields: pipeline.append({"$unset": "_id"}) return pipeline