3
3
4
4
from django .core .exceptions import EmptyResultSet , FullResultSet
5
5
from django .db import DatabaseError , IntegrityError , NotSupportedError
6
- from django .db .models .expressions import Case , When
6
+ from django .db .models .expressions import Case , Col , When
7
7
from django .db .models .functions import Mod
8
8
from django .db .models .lookups import Exact
9
9
from django .db .models .sql .constants import INNER
@@ -105,6 +105,7 @@ def join(self, compiler, connection):
105
105
lhs_fields = []
106
106
rhs_fields = []
107
107
# Add a join condition for each pair of joining fields.
108
+ parent_template = "parent__field__"
108
109
for lhs , rhs in self .join_fields :
109
110
lhs , rhs = connection .ops .prepare_join_on_clause (
110
111
self .parent_alias , lhs , compiler .collection_name , rhs
@@ -113,8 +114,41 @@ def join(self, compiler, connection):
113
114
# In the lookup stage, the reference to this column doesn't include
114
115
# the collection name.
115
116
rhs_fields .append (rhs .as_mql (compiler , connection ))
117
+ # Handle any join conditions besides matching field pairs.
118
+ extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
119
+ if extra :
120
+ columns = []
121
+ for expr in extra .leaves ():
122
+ # Determine whether the column needs to be transformed or rerouted
123
+ # as part of the subquery.
124
+ for hand_side in ["lhs" , "rhs" ]:
125
+ hand_side_value = getattr (expr , hand_side , None )
126
+ if isinstance (hand_side_value , Col ):
127
+ # If the column is not part of the joined table, add it to
128
+ # lhs_fields.
129
+ if hand_side_value .alias != self .table_name :
130
+ pos = len (lhs_fields )
131
+ lhs_fields .append (expr .lhs .as_mql (compiler , connection ))
132
+ else :
133
+ pos = None
134
+ columns .append ((hand_side_value , pos ))
135
+ # Replace columns in the extra conditions with new column references
136
+ # based on their rerouted positions in the join pipeline.
137
+ replacements = {}
138
+ for col , parent_pos in columns :
139
+ column_target = Col (compiler .collection_name , expr .output_field .__class__ ())
140
+ if parent_pos is not None :
141
+ target_col = f"${ parent_template } { parent_pos } "
142
+ column_target .target .db_column = target_col
143
+ column_target .target .set_attributes_from_name (target_col )
144
+ else :
145
+ column_target .target = col .target
146
+ replacements [col ] = column_target
147
+ # Apply the transformed expressions in the extra condition.
148
+ extra_condition = [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
149
+ else :
150
+ extra_condition = []
116
151
117
- parent_template = "parent__field__"
118
152
lookup_pipeline = [
119
153
{
120
154
"$lookup" : {
@@ -140,6 +174,7 @@ def join(self, compiler, connection):
140
174
{"$eq" : [f"$${ parent_template } { i } " , field ]}
141
175
for i , field in enumerate (rhs_fields )
142
176
]
177
+ + extra_condition
143
178
}
144
179
}
145
180
}
0 commit comments