Skip to content

Commit d590103

Browse files
committed
Push simple filter conditions into $lookup stage.
1 parent 0deeb1d commit d590103

File tree

3 files changed

+53
-23
lines changed

3 files changed

+53
-23
lines changed

django_mongodb_backend/compiler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1010
from django.db.models.functions.comparison import Coalesce
1111
from django.db.models.functions.math import Power
12-
from django.db.models.lookups import IsNull
12+
from django.db.models.lookups import IsNull, Lookup
1313
from django.db.models.sql import compiler
1414
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
1515
from django.db.models.sql.datastructures import BaseTable
16+
from django.db.models.sql.where import AND
1617
from django.utils.functional import cached_property
1718
from pymongo import ASCENDING, DESCENDING
1819

1920
from .query import MongoQuery, wrap_database_errors
21+
from .query_utils import is_direct_value
2022

2123

2224
class SQLCompiler(compiler.SQLCompiler):
@@ -548,10 +550,22 @@ def get_combinator_queries(self):
548550

549551
def get_lookup_pipeline(self):
550552
result = []
553+
where = self.get_where()
554+
promote_filters = defaultdict(list)
555+
for expr in where.children if where and where.connector == AND else ():
556+
if (
557+
isinstance(expr, Lookup)
558+
and isinstance(expr.lhs, Col)
559+
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, Value))
560+
):
561+
promote_filters[expr.lhs.alias].append(expr)
562+
551563
for alias in tuple(self.query.alias_map):
552564
if not self.query.alias_refcount[alias] or self.collection_name == alias:
553565
continue
554-
result += self.query.alias_map[alias].as_mql(self, self.connection)
566+
result += self.query.alias_map[alias].as_mql(
567+
self, self.connection, promote_filters[alias]
568+
)
555569
return result
556570

557571
def _get_aggregate_expressions(self, expr):

django_mongodb_backend/query.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -123,25 +123,12 @@ def extra_where(self, compiler, connection): # noqa: ARG001
123123
raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")
124124

125125

126-
def join(self, compiler, connection):
127-
lookup_pipeline = []
128-
lhs_fields = []
129-
rhs_fields = []
130-
# Add a join condition for each pair of joining fields.
131-
parent_template = "parent__field__"
132-
for lhs, rhs in self.join_fields:
133-
lhs, rhs = connection.ops.prepare_join_on_clause(
134-
self.parent_alias, lhs, compiler.collection_name, rhs
135-
)
136-
lhs_fields.append(lhs.as_mql(compiler, connection))
137-
# In the lookup stage, the reference to this column doesn't include
138-
# the collection name.
139-
rhs_fields.append(rhs.as_mql(compiler, connection))
140-
# Handle any join conditions besides matching field pairs.
141-
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
142-
if extra:
126+
def join(self, compiler, connection, pushed_expressions=None):
127+
def _get_reroot_replacements(expressions):
128+
if not expressions:
129+
return []
143130
columns = []
144-
for expr in extra.leaves():
131+
for expr in expressions:
145132
# Determine whether the column needs to be transformed or rerouted
146133
# as part of the subquery.
147134
for hand_side in ["lhs", "rhs"]:
@@ -159,18 +146,45 @@ def join(self, compiler, connection):
159146
# based on their rerouted positions in the join pipeline.
160147
replacements = {}
161148
for col, parent_pos in columns:
162-
column_target = Col(compiler.collection_name, expr.output_field.__class__())
149+
column_target = Col(compiler.collection_name, col.target, col.output_field)
163150
if parent_pos is not None:
164151
target_col = f"${parent_template}{parent_pos}"
165152
column_target.target.db_column = target_col
166153
column_target.target.set_attributes_from_name(target_col)
167154
else:
168155
column_target.target = col.target
169156
replacements[col] = column_target
170-
# Apply the transformed expressions in the extra condition.
157+
return replacements
158+
159+
lookup_pipeline = []
160+
lhs_fields = []
161+
rhs_fields = []
162+
# Add a join condition for each pair of joining fields.
163+
parent_template = "parent__field__"
164+
for lhs, rhs in self.join_fields:
165+
lhs, rhs = connection.ops.prepare_join_on_clause(
166+
self.parent_alias, lhs, compiler.collection_name, rhs
167+
)
168+
lhs_fields.append(lhs.as_mql(compiler, connection))
169+
# In the lookup stage, the reference to this column doesn't include
170+
# the collection name.
171+
rhs_fields.append(rhs.as_mql(compiler, connection))
172+
# Handle any join conditions besides matching field pairs.
173+
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
174+
175+
if extra:
176+
replacements = _get_reroot_replacements(extra.leaves())
171177
extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)]
172178
else:
173179
extra_condition = []
180+
if self.join_type == INNER:
181+
rerooted_replacement = _get_reroot_replacements(pushed_expressions)
182+
resolved_pushed_expressions = [
183+
expr.replace_expressions(rerooted_replacement).as_mql(compiler, connection)
184+
for expr in pushed_expressions
185+
]
186+
else:
187+
resolved_pushed_expressions = []
174188

175189
lookup_pipeline = [
176190
{
@@ -198,6 +212,7 @@ def join(self, compiler, connection):
198212
for i, field in enumerate(rhs_fields)
199213
]
200214
+ extra_condition
215+
+ resolved_pushed_expressions
201216
}
202217
}
203218
}

tests/queries_/test_mql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def test_join(self):
2020
"{'$lookup': {'from': 'queries__author', "
2121
"'let': {'parent__field__0': '$author_id'}, "
2222
"'pipeline': [{'$match': {'$expr': "
23-
"{'$and': [{'$eq': ['$$parent__field__0', '$_id']}]}}}], 'as': 'queries__author'}}, "
23+
"{'$and': [{'$eq': ['$$parent__field__0', '$_id']}, "
24+
"{'$eq': ['$name', 'Bob']}]}}}], 'as': 'queries__author'}}, "
2425
"{'$unwind': '$queries__author'}, "
2526
"{'$match': {'$expr': {'$eq': ['$queries__author.name', 'Bob']}}}])",
2627
)

0 commit comments

Comments
 (0)