diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 1c20faa18..38d632a72 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -9,17 +9,17 @@ from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When from django.db.models.functions.comparison import Coalesce from django.db.models.functions.math import Power -from django.db.models.lookups import IsNull, Lookup +from django.db.models.lookups import IsNull from django.db.models.sql import compiler from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE from django.db.models.sql.datastructures import BaseTable -from django.db.models.sql.where import AND, WhereNode +from django.db.models.sql.where import AND, OR, XOR, NothingNode, WhereNode from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING from .expressions.search import SearchExpression, SearchVector from .query import MongoQuery, wrap_database_errors -from .query_utils import is_direct_value +from .query_utils import is_constant_value class SQLCompiler(compiler.SQLCompiler): @@ -658,27 +658,61 @@ def get_combinator_queries(self): combinator_pipeline.append({"$unset": "_id"}) return combinator_pipeline + def _get_pushable_conditions(self): + def collect_pushable(expr, negated=False): + if expr is None or isinstance(expr, NothingNode): + return {} + if isinstance(expr, WhereNode): + negated ^= expr.negated + pushable_expressions = [ + collect_pushable(sub_expr, negated=negated) + for sub_expr in expr.children + if sub_expr is not None + ] + operator = expr.connector + if operator == XOR: + return {} + if negated: + operator = OR if operator == AND else AND + alias_children = defaultdict(list) + for pe in pushable_expressions: + for alias, expressions in pe.items(): + alias_children[alias].append(expressions) + result = {} + for alias, children in alias_children.items(): + result[alias] = WhereNode( + children=children, + negated=False, + connector=operator, + ) + if operator == AND: + return result + shared_alias = ( + set.intersection(*(set(pe) for pe in pushable_expressions)) + if pushable_expressions + else set() + ) + return {k: v for k, v in result.items() if k in shared_alias} + if isinstance(expr.lhs, Col) and ( + is_constant_value(expr.rhs) or getattr(expr.rhs, "is_simple_column", False) + ): + alias = expr.lhs.alias + expr = WhereNode(children=[expr], negated=negated) + return {alias: expr} + return {} + + return collect_pushable(self.get_where()) + def get_lookup_pipeline(self): result = [] # To improve join performance, push conditions (filters) from the # WHERE ($match) clause to the JOIN ($lookup) clause. - where = self.get_where() - pushed_filters = defaultdict(list) - for expr in where.children if where and where.connector == AND else (): - # Push only basic lookups; no subqueries or complex conditions. - # To avoid duplication across subqueries, only use the LHS target - # table. - if ( - isinstance(expr, Lookup) - and isinstance(expr.lhs, Col) - and (is_direct_value(expr.rhs) or isinstance(expr.rhs, (Value, Col))) - ): - pushed_filters[expr.lhs.alias].append(expr) + pushed_filters = self._get_pushable_conditions() for alias in tuple(self.query.alias_map): if not self.query.alias_refcount[alias] or self.collection_name == alias: continue result += self.query.alias_map[alias].as_mql( - self, self.connection, WhereNode(pushed_filters[alias], connector=AND) + self, self.connection, pushed_filters.get(alias) ) return result diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index fa04feb75..ea892ec9f 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -1,4 +1,5 @@ from django.core.exceptions import FullResultSet +from django.db.models import F from django.db.models.aggregates import Aggregate from django.db.models.expressions import CombinedExpression, Func, Value from django.db.models.sql.query import Query @@ -67,7 +68,7 @@ def is_constant_value(value): else: constants_sub_expressions = True constants_sub_expressions = constants_sub_expressions and not ( - isinstance(value, Query) + isinstance(value, Query | F) or value.contains_aggregate or value.contains_over_clause or value.contains_column_references diff --git a/django_mongodb_backend/test.py b/django_mongodb_backend/test.py index ee35b4e21..094df7e7d 100644 --- a/django_mongodb_backend/test.py +++ b/django_mongodb_backend/test.py @@ -6,6 +6,28 @@ class MongoTestCaseMixin: maxDiff = None + COMMUTATIVE_OPERATORS = {"$and", "$or", "$all"} + + @staticmethod + def _normalize_query(obj): + if isinstance(obj, dict): + normalized = {} + for k, v in obj.items(): + if k in MongoTestCaseMixin.COMMUTATIVE_OPERATORS and isinstance(v, list): + # Only sort for commutative operators + normalized[k] = sorted( + (MongoTestCaseMixin._normalize_query(i) for i in v), key=lambda x: str(x) + ) + else: + normalized[k] = MongoTestCaseMixin._normalize_query(v) + return normalized + + if isinstance(obj, list): + # Lists not under commutative ops keep their order + return [MongoTestCaseMixin._normalize_query(i) for i in obj] + + return obj + def assertAggregateQuery(self, query, expected_collection, expected_pipeline): """ Assert that the logged query is equal to: @@ -16,6 +38,10 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline): self.assertEqual(operator, "aggregate") self.assertEqual(collection, expected_collection) self.assertEqual( - eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307 - expected_pipeline, + self._normalize_query( + eval( # noqa: S307 + pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {} + ) + ), + self._normalize_query(expected_pipeline), ) diff --git a/tests/queries_/test_mql.py b/tests/queries_/test_mql.py index e8837bf8a..9d061a106 100644 --- a/tests/queries_/test_mql.py +++ b/tests/queries_/test_mql.py @@ -281,7 +281,14 @@ def test_negated_related_filter_is_not_pushable(self): "pipeline": [ { "$match": { - "$expr": {"$and": [{"$eq": ["$$parent__field__0", "$_id"]}]} + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"$nor": [{"name": "John"}]}, + ] } } ],