@@ -123,25 +123,12 @@ def extra_where(self, compiler, connection): # noqa: ARG001
123
123
raise NotSupportedError ("QuerySet.extra() is not supported on MongoDB." )
124
124
125
125
126
- def join (self , compiler , connection ):
127
- lookup_pipeline = []
128
- lhs_fields = []
129
- rhs_fields = []
130
- # Add a join condition for each pair of joining fields.
131
- parent_template = "parent__field__"
132
- for lhs , rhs in self .join_fields :
133
- lhs , rhs = connection .ops .prepare_join_on_clause (
134
- self .parent_alias , lhs , compiler .collection_name , rhs
135
- )
136
- lhs_fields .append (lhs .as_mql (compiler , connection ))
137
- # In the lookup stage, the reference to this column doesn't include
138
- # the collection name.
139
- rhs_fields .append (rhs .as_mql (compiler , connection ))
140
- # Handle any join conditions besides matching field pairs.
141
- extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
142
- if extra :
126
+ def join (self , compiler , connection , pushed_expressions = None ):
127
+ def _get_reroot_replacements (expressions ):
128
+ if not expressions :
129
+ return []
143
130
columns = []
144
- for expr in extra . leaves () :
131
+ for expr in expressions :
145
132
# Determine whether the column needs to be transformed or rerouted
146
133
# as part of the subquery.
147
134
for hand_side in ["lhs" , "rhs" ]:
@@ -159,18 +146,45 @@ def join(self, compiler, connection):
159
146
# based on their rerouted positions in the join pipeline.
160
147
replacements = {}
161
148
for col , parent_pos in columns :
162
- column_target = Col (compiler .collection_name , expr . output_field . __class__ () )
149
+ column_target = Col (compiler .collection_name , col . target , col . output_field )
163
150
if parent_pos is not None :
164
151
target_col = f"${ parent_template } { parent_pos } "
165
152
column_target .target .db_column = target_col
166
153
column_target .target .set_attributes_from_name (target_col )
167
154
else :
168
155
column_target .target = col .target
169
156
replacements [col ] = column_target
170
- # Apply the transformed expressions in the extra condition.
157
+ return replacements
158
+
159
+ lookup_pipeline = []
160
+ lhs_fields = []
161
+ rhs_fields = []
162
+ # Add a join condition for each pair of joining fields.
163
+ parent_template = "parent__field__"
164
+ for lhs , rhs in self .join_fields :
165
+ lhs , rhs = connection .ops .prepare_join_on_clause (
166
+ self .parent_alias , lhs , compiler .collection_name , rhs
167
+ )
168
+ lhs_fields .append (lhs .as_mql (compiler , connection ))
169
+ # In the lookup stage, the reference to this column doesn't include
170
+ # the collection name.
171
+ rhs_fields .append (rhs .as_mql (compiler , connection ))
172
+ # Handle any join conditions besides matching field pairs.
173
+ extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
174
+
175
+ if extra :
176
+ replacements = _get_reroot_replacements (extra .leaves ())
171
177
extra_condition = [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
172
178
else :
173
179
extra_condition = []
180
+ if self .join_type == INNER :
181
+ rerooted_replacement = _get_reroot_replacements (pushed_expressions )
182
+ resolved_pushed_expressions = [
183
+ expr .replace_expressions (rerooted_replacement ).as_mql (compiler , connection )
184
+ for expr in pushed_expressions
185
+ ]
186
+ else :
187
+ resolved_pushed_expressions = []
174
188
175
189
lookup_pipeline = [
176
190
{
@@ -198,6 +212,7 @@ def join(self, compiler, connection):
198
212
for i , field in enumerate (rhs_fields )
199
213
]
200
214
+ extra_condition
215
+ + resolved_pushed_expressions
201
216
}
202
217
}
203
218
}
0 commit comments