3
3
4
4
from django .core .exceptions import EmptyResultSet , FullResultSet
5
5
from django .db import DatabaseError , IntegrityError , NotSupportedError
6
- from django .db .models .expressions import Case , When
6
+ from django .db .models .expressions import Case , Col , When
7
7
from django .db .models .functions import Mod
8
8
from django .db .models .lookups import Exact
9
9
from django .db .models .sql .constants import INNER
@@ -101,6 +101,7 @@ def join(self, compiler, connection):
101
101
lhs_fields = []
102
102
rhs_fields = []
103
103
# Add a join condition for each pair of joining fields.
104
+ parent_template = "parent__field__"
104
105
for lhs , rhs in self .join_fields :
105
106
lhs , rhs = connection .ops .prepare_join_on_clause (
106
107
self .parent_alias , lhs , compiler .collection_name , rhs
@@ -109,8 +110,28 @@ def join(self, compiler, connection):
109
110
# In the lookup stage, the reference to this column doesn't include
110
111
# the collection name.
111
112
rhs_fields .append (rhs .as_mql (compiler , connection ))
113
+ extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
114
+ if extra :
115
+ columns = []
116
+ for expr in extra .leaves ():
117
+ if hasattr (expr , "lhs" ) and isinstance (expr .lhs , Col ):
118
+ columns .append ((expr .lhs , len (lhs_fields )))
119
+ lhs_fields .append (expr .lhs .as_mql (compiler , connection ))
120
+ if hasattr (expr , "rhs" ) and isinstance (expr .rhs , Col ):
121
+ columns .append ((expr .rhs , None ))
122
+ replacements = {}
123
+ for col , parent_pos in columns :
124
+ column_target = col .copy ()
125
+ if column_target .alias == self .table_name :
126
+ column_target .alias = compiler .collection_name
127
+ else :
128
+ column_target .target .db_column = f"{ parent_template } { parent_pos } "
129
+ column_target .target .set_attributes_from_name (f"{ parent_template } { len (lhs_fields )} " )
130
+ replacements [col ] = column_target
131
+ extra_condition = [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
132
+ else :
133
+ extra_condition = []
112
134
113
- parent_template = "parent__field__"
114
135
lookup_pipeline = [
115
136
{
116
137
"$lookup" : {
@@ -136,6 +157,7 @@ def join(self, compiler, connection):
136
157
{"$eq" : [f"$${ parent_template } { i } " , {"$toString" : field }]}
137
158
for i , field in enumerate (rhs_fields )
138
159
]
160
+ + extra_condition
139
161
}
140
162
}
141
163
}
0 commit comments