@@ -123,25 +123,21 @@ def extra_where(self, compiler, connection): # noqa: ARG001
123
123
raise NotSupportedError ("QuerySet.extra() is not supported on MongoDB." )
124
124
125
125
126
- def join (self , compiler , connection ):
127
- lookup_pipeline = []
128
- lhs_fields = []
129
- rhs_fields = []
130
- # Add a join condition for each pair of joining fields.
126
+ def join (self , compiler , connection , pushed_filter_expression = None ):
127
+ """
128
+ Generate a MongoDB $lookup stage for a join.
129
+
130
+ `pushed_filter_expression` is a Where expression involving fields from the
131
+ joined collection which can be pushed from the WHERE ($match) clause to the
132
+ JOIN ($lookup) clause to improve performance.
133
+ """
131
134
parent_template = "parent__field__"
132
- for lhs , rhs in self .join_fields :
133
- lhs , rhs = connection .ops .prepare_join_on_clause (
134
- self .parent_alias , lhs , compiler .collection_name , rhs
135
- )
136
- lhs_fields .append (lhs .as_mql (compiler , connection ))
137
- # In the lookup stage, the reference to this column doesn't include
138
- # the collection name.
139
- rhs_fields .append (rhs .as_mql (compiler , connection ))
140
- # Handle any join conditions besides matching field pairs.
141
- extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
142
- if extra :
135
+
136
+ def _get_reroot_replacements (expression ):
137
+ if not expression :
138
+ return None
143
139
columns = []
144
- for expr in extra .leaves ():
140
+ for expr in expression .leaves ():
145
141
# Determine whether the column needs to be transformed or rerouted
146
142
# as part of the subquery.
147
143
for hand_side in ["lhs" , "rhs" ]:
@@ -151,27 +147,61 @@ def join(self, compiler, connection):
151
147
# lhs_fields.
152
148
if hand_side_value .alias != self .table_alias :
153
149
pos = len (lhs_fields )
154
- lhs_fields .append (expr . lhs .as_mql (compiler , connection ))
150
+ lhs_fields .append (hand_side_value .as_mql (compiler , connection ))
155
151
else :
156
152
pos = None
157
153
columns .append ((hand_side_value , pos ))
158
154
# Replace columns in the extra conditions with new column references
159
155
# based on their rerouted positions in the join pipeline.
160
156
replacements = {}
161
157
for col , parent_pos in columns :
162
- column_target = Col (compiler .collection_name , expr .output_field .__class__ ())
158
+ target = col .target .clone ()
159
+ target .remote_field = col .target .remote_field
160
+ column_target = Col (compiler .collection_name , target )
163
161
if parent_pos is not None :
164
162
target_col = f"${ parent_template } { parent_pos } "
165
163
column_target .target .db_column = target_col
166
164
column_target .target .set_attributes_from_name (target_col )
167
165
else :
168
166
column_target .target = col .target
169
167
replacements [col ] = column_target
170
- # Apply the transformed expressions in the extra condition.
171
- extra_condition = [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
172
- else :
173
- extra_condition = []
168
+ return replacements
174
169
170
+ lookup_pipeline = []
171
+ lhs_fields = []
172
+ rhs_fields = []
173
+ # Add a join condition for each pair of joining fields.
174
+ for lhs , rhs in self .join_fields :
175
+ lhs , rhs = connection .ops .prepare_join_on_clause (
176
+ self .parent_alias , lhs , compiler .collection_name , rhs
177
+ )
178
+ lhs_fields .append (lhs .as_mql (compiler , connection ))
179
+ # In the lookup stage, the reference to this column doesn't include the
180
+ # collection name.
181
+ rhs_fields .append (rhs .as_mql (compiler , connection ))
182
+ # Handle any join conditions besides matching field pairs.
183
+ extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
184
+ extra_conditions = []
185
+ if extra :
186
+ replacements = _get_reroot_replacements (extra )
187
+ extra_conditions .append (
188
+ extra .replace_expressions (replacements ).as_mql (compiler , connection )
189
+ )
190
+ # pushed_filter_expression is a Where expression from the outer WHERE
191
+ # clause that involves fields from the joined (right-hand) table and
192
+ # possibly the outer (left-hand) table. If it can be safely evaluated
193
+ # within the $lookup pipeline (e.g., field comparisons like
194
+ # right.status = left.id), it is "pushed" into the join's $match stage to
195
+ # reduce the volume of joined documents. This only applies to INNER JOINs,
196
+ # as pushing filters into a LEFT JOIN can change the semantics of the
197
+ # result. LEFT JOINs may rely on null checks to detect missing RHS.
198
+ if pushed_filter_expression and self .join_type == INNER :
199
+ rerooted_replacement = _get_reroot_replacements (pushed_filter_expression )
200
+ extra_conditions .append (
201
+ pushed_filter_expression .replace_expressions (rerooted_replacement ).as_mql (
202
+ compiler , connection
203
+ )
204
+ )
175
205
lookup_pipeline = [
176
206
{
177
207
"$lookup" : {
@@ -197,7 +227,7 @@ def join(self, compiler, connection):
197
227
{"$eq" : [f"$${ parent_template } { i } " , field ]}
198
228
for i , field in enumerate (rhs_fields )
199
229
]
200
- + extra_condition
230
+ + extra_conditions
201
231
}
202
232
}
203
233
}
0 commit comments