Skip to content

Commit 2a135cc

Browse files
committed
Edits.
1 parent 06894c5 commit 2a135cc

File tree

2 files changed

+38
-15
lines changed

2 files changed

+38
-15
lines changed

django_mongodb_backend/compiler.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,17 @@
99
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1010
from django.db.models.functions.comparison import Coalesce
1111
from django.db.models.functions.math import Power
12-
from django.db.models.lookups import IsNull, Lookup
12+
from django.db.models.lookups import IsNull
1313
from django.db.models.sql import compiler
1414
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
1515
from django.db.models.sql.datastructures import BaseTable
16-
from django.db.models.sql.where import AND, WhereNode
16+
from django.db.models.sql.where import AND, OR, XOR, WhereNode
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

2020
from .expressions.search import SearchExpression, SearchVector
2121
from .query import MongoQuery, wrap_database_errors
22-
from .query_utils import is_direct_value
22+
from .query_utils import is_constant_value
2323

2424

2525
class SQLCompiler(compiler.SQLCompiler):
@@ -658,22 +658,43 @@ def get_combinator_queries(self):
658658
combinator_pipeline.append({"$unset": "_id"})
659659
return combinator_pipeline
660660

661+
def _get_pushable_conditions(self):
662+
def collect_pushable(expr, negated=False):
663+
if isinstance(expr, WhereNode):
664+
pushable_expressions = (
665+
collect_pushable(sub_expr, negated=negated != expr.negated)
666+
for sub_expr in expr.children
667+
)
668+
operator = expr.connector
669+
if operator == XOR:
670+
return {}
671+
if negated:
672+
operator = OR if operator == AND else AND
673+
result = defaultdict(list, next(pushable_expressions, {}))
674+
shared_alias = set(result)
675+
for pe in pushable_expressions:
676+
shared_alias &= set(pe)
677+
for alias, expressions in pe.items():
678+
result[alias] += expressions
679+
if operator == AND:
680+
return result
681+
return {k: v for k, v in result.items() if k in shared_alias}
682+
if expr.lhs.is_simple_column and (
683+
is_constant_value(expr.rhs) or expr.rhs.is_simple_column
684+
):
685+
alias = expr.lhs.alias
686+
if negated:
687+
expr = WhereNode(children=[expr], negated=True)
688+
return {expr.lhs.alias: [expr]}
689+
return {}
690+
691+
return collect_pushable(self.get_where())
692+
661693
def get_lookup_pipeline(self):
662694
result = []
663695
# To improve join performance, push conditions (filters) from the
664696
# WHERE ($match) clause to the JOIN ($lookup) clause.
665-
where = self.get_where()
666-
pushed_filters = defaultdict(list)
667-
for expr in where.children if where and where.connector == AND else ():
668-
# Push only basic lookups; no subqueries or complex conditions.
669-
# To avoid duplication across subqueries, only use the LHS target
670-
# table.
671-
if (
672-
isinstance(expr, Lookup)
673-
and isinstance(expr.lhs, Col)
674-
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, (Value, Col)))
675-
):
676-
pushed_filters[expr.lhs.alias].append(expr)
697+
pushed_filters = self._get_pushable_conditions()
677698
for alias in tuple(self.query.alias_map):
678699
if not self.query.alias_refcount[alias] or self.collection_name == alias:
679700
continue

tests/queries_/test_mql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def test_filter_on_local_and_nested_join_fields(self):
267267
)
268268

269269
def test_negated_related_filter_is_not_pushable(self):
270+
# import ipdb
271+
# ipdb.set_trace()
270272
with self.assertNumQueries(1) as ctx:
271273
list(Book.objects.filter(~models.Q(author__name="John")))
272274
self.assertAggregateQuery(

0 commit comments

Comments
 (0)