Skip to content

Commit 9831e9b

Browse files
WaVEVtimgraham
authored andcommitted
INTPYTHON-635 Improve join performance by pushing simple filter conditions to $lookup
1 parent 66b2eb0 commit 9831e9b

File tree

5 files changed

+640
-28
lines changed

5 files changed

+640
-28
lines changed

django_mongodb_backend/compiler.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1010
from django.db.models.functions.comparison import Coalesce
1111
from django.db.models.functions.math import Power
12-
from django.db.models.lookups import IsNull
12+
from django.db.models.lookups import IsNull, Lookup
1313
from django.db.models.sql import compiler
1414
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
1515
from django.db.models.sql.datastructures import BaseTable
16+
from django.db.models.sql.where import AND, WhereNode
1617
from django.utils.functional import cached_property
1718
from pymongo import ASCENDING, DESCENDING
1819

1920
from .query import MongoQuery, wrap_database_errors
21+
from .query_utils import is_direct_value
2022

2123

2224
class SQLCompiler(compiler.SQLCompiler):
@@ -548,10 +550,26 @@ def get_combinator_queries(self):
548550

549551
def get_lookup_pipeline(self):
550552
result = []
553+
# To improve join performance, push conditions (filters) from the
554+
# WHERE ($match) clause to the JOIN ($lookup) clause.
555+
where = self.get_where()
556+
pushed_filters = defaultdict(list)
557+
for expr in where.children if where and where.connector == AND else ():
558+
# Push only basic lookups; no subqueries or complex conditions.
559+
# To avoid duplication across subqueries, only use the LHS target
560+
# table.
561+
if (
562+
isinstance(expr, Lookup)
563+
and isinstance(expr.lhs, Col)
564+
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, Value | Col))
565+
):
566+
pushed_filters[expr.lhs.alias].append(expr)
551567
for alias in tuple(self.query.alias_map):
552568
if not self.query.alias_refcount[alias] or self.collection_name == alias:
553569
continue
554-
result += self.query.alias_map[alias].as_mql(self, self.connection)
570+
result += self.query.alias_map[alias].as_mql(
571+
self, self.connection, WhereNode(pushed_filters[alias], connector=AND)
572+
)
555573
return result
556574

557575
def _get_aggregate_expressions(self, expr):

django_mongodb_backend/query.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -123,25 +123,21 @@ def extra_where(self, compiler, connection): # noqa: ARG001
123123
raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")
124124

125125

126-
def join(self, compiler, connection):
127-
lookup_pipeline = []
128-
lhs_fields = []
129-
rhs_fields = []
130-
# Add a join condition for each pair of joining fields.
126+
def join(self, compiler, connection, pushed_filter_expression=None):
127+
"""
128+
Generate a MongoDB $lookup stage for a join.
129+
130+
`pushed_filter_expression` is a Where expression involving fields from the
131+
joined collection which can be pushed from the WHERE ($match) clause to the
132+
JOIN ($lookup) clause to improve performance.
133+
"""
131134
parent_template = "parent__field__"
132-
for lhs, rhs in self.join_fields:
133-
lhs, rhs = connection.ops.prepare_join_on_clause(
134-
self.parent_alias, lhs, compiler.collection_name, rhs
135-
)
136-
lhs_fields.append(lhs.as_mql(compiler, connection))
137-
# In the lookup stage, the reference to this column doesn't include
138-
# the collection name.
139-
rhs_fields.append(rhs.as_mql(compiler, connection))
140-
# Handle any join conditions besides matching field pairs.
141-
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
142-
if extra:
135+
136+
def _get_reroot_replacements(expression):
137+
if not expression:
138+
return None
143139
columns = []
144-
for expr in extra.leaves():
140+
for expr in expression.leaves():
145141
# Determine whether the column needs to be transformed or rerouted
146142
# as part of the subquery.
147143
for hand_side in ["lhs", "rhs"]:
@@ -151,27 +147,61 @@ def join(self, compiler, connection):
151147
# lhs_fields.
152148
if hand_side_value.alias != self.table_alias:
153149
pos = len(lhs_fields)
154-
lhs_fields.append(expr.lhs.as_mql(compiler, connection))
150+
lhs_fields.append(hand_side_value.as_mql(compiler, connection))
155151
else:
156152
pos = None
157153
columns.append((hand_side_value, pos))
158154
# Replace columns in the extra conditions with new column references
159155
# based on their rerouted positions in the join pipeline.
160156
replacements = {}
161157
for col, parent_pos in columns:
162-
column_target = Col(compiler.collection_name, expr.output_field.__class__())
158+
target = col.target.clone()
159+
target.remote_field = col.target.remote_field
160+
column_target = Col(compiler.collection_name, target)
163161
if parent_pos is not None:
164162
target_col = f"${parent_template}{parent_pos}"
165163
column_target.target.db_column = target_col
166164
column_target.target.set_attributes_from_name(target_col)
167165
else:
168166
column_target.target = col.target
169167
replacements[col] = column_target
170-
# Apply the transformed expressions in the extra condition.
171-
extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)]
172-
else:
173-
extra_condition = []
168+
return replacements
174169

170+
lookup_pipeline = []
171+
lhs_fields = []
172+
rhs_fields = []
173+
# Add a join condition for each pair of joining fields.
174+
for lhs, rhs in self.join_fields:
175+
lhs, rhs = connection.ops.prepare_join_on_clause(
176+
self.parent_alias, lhs, compiler.collection_name, rhs
177+
)
178+
lhs_fields.append(lhs.as_mql(compiler, connection))
179+
# In the lookup stage, the reference to this column doesn't include the
180+
# collection name.
181+
rhs_fields.append(rhs.as_mql(compiler, connection))
182+
# Handle any join conditions besides matching field pairs.
183+
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
184+
extra_conditions = []
185+
if extra:
186+
replacements = _get_reroot_replacements(extra)
187+
extra_conditions.append(
188+
extra.replace_expressions(replacements).as_mql(compiler, connection)
189+
)
190+
# pushed_filter_expression is a Where expression from the outer WHERE
191+
# clause that involves fields from the joined (right-hand) table and
192+
# possibly the outer (left-hand) table. If it can be safely evaluated
193+
# within the $lookup pipeline (e.g., field comparisons like
194+
# right.status = left.id), it is "pushed" into the join's $match stage to
195+
# reduce the volume of joined documents. This only applies to INNER JOINs,
196+
# as pushing filters into a LEFT JOIN can change the semantics of the
197+
# result. LEFT JOINs may rely on null checks to detect missing RHS.
198+
if pushed_filter_expression and self.join_type == INNER:
199+
rerooted_replacement = _get_reroot_replacements(pushed_filter_expression)
200+
extra_conditions.append(
201+
pushed_filter_expression.replace_expressions(rerooted_replacement).as_mql(
202+
compiler, connection
203+
)
204+
)
175205
lookup_pipeline = [
176206
{
177207
"$lookup": {
@@ -197,7 +227,7 @@ def join(self, compiler, connection):
197227
{"$eq": [f"$${parent_template}{i}", field]}
198228
for i, field in enumerate(rhs_fields)
199229
]
200-
+ extra_condition
230+
+ extra_conditions
201231
}
202232
}
203233
}

docs/source/releases/5.2.x.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@ Bug fixes
3232
databases.
3333
- :meth:`QuerySet.explain() <django.db.models.query.QuerySet.explain>` now
3434
:ref:`returns a string that can be parsed as JSON <queryset-explain>`.
35+
36+
Performance improvements
37+
------------------------
38+
3539
- Improved ``QuerySet`` performance by removing low limit on server-side chunking.
40+
- Improved ``QuerySet`` join (``$lookup``) performance by pushing some simple
41+
conditions from the ``WHERE`` (``$match``) clause to the ``$lookup`` stage.
3642

3743
5.2.0 beta 1
3844
============

tests/queries_/models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,18 @@ class Meta:
5353

5454
def __str__(self):
5555
return str(self.pk)
56+
57+
58+
class Reader(models.Model):
59+
name = models.CharField(max_length=20)
60+
61+
def __str__(self):
62+
return self.name
63+
64+
65+
class Library(models.Model):
66+
name = models.CharField(max_length=20)
67+
readers = models.ManyToManyField(Reader, related_name="libraries")
68+
69+
def __str__(self):
70+
return self.name

0 commit comments

Comments
 (0)