3
3
4
4
from django .core .exceptions import EmptyResultSet , FullResultSet
5
5
from django .db import DatabaseError , IntegrityError
6
- from django .db .models .expressions import Case , When
6
+ from django .db .models .expressions import Case , Col , When
7
7
from django .db .models .functions import Mod
8
8
from django .db .models .lookups import Exact
9
9
from django .db .models .sql .constants import INNER
@@ -97,6 +97,7 @@ def join(self, compiler, connection):
97
97
lhs_fields = []
98
98
rhs_fields = []
99
99
# Add a join condition for each pair of joining fields.
100
+ parent_template = "parent__field__"
100
101
for lhs , rhs in self .join_fields :
101
102
lhs , rhs = connection .ops .prepare_join_on_clause (
102
103
self .parent_alias , lhs , self .table_name , rhs
@@ -105,8 +106,28 @@ def join(self, compiler, connection):
105
106
# In the lookup stage, the reference to this column doesn't include
106
107
# the collection name.
107
108
rhs_fields .append (rhs .as_mql (compiler , connection ).replace (f"{ self .table_name } ." , "" , 1 ))
109
+ extra = self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
110
+ if extra :
111
+ columns = []
112
+ for expr in extra .leaves ():
113
+ if hasattr (expr , "lhs" ) and isinstance (expr .lhs , Col ):
114
+ columns .append ((expr .lhs , len (lhs_fields )))
115
+ lhs_fields .append (expr .lhs .as_mql (compiler , connection ))
116
+ if hasattr (expr , "rhs" ) and isinstance (expr .rhs , Col ):
117
+ columns .append ((expr .rhs , None ))
118
+ replacements = {}
119
+ for col , parent_pos in columns :
120
+ column_target = col .copy ()
121
+ if column_target .alias == self .table_name :
122
+ column_target .alias = compiler .collection_name
123
+ else :
124
+ column_target .target .db_column = f"{ parent_template } { parent_pos } "
125
+ column_target .target .set_attributes_from_name (f"{ parent_template } { len (lhs_fields )} " )
126
+ replacements [col ] = column_target
127
+ extra_condition = [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
128
+ else :
129
+ extra_condition = []
108
130
109
- parent_template = "parent__field__"
110
131
lookup_pipeline = [
111
132
{
112
133
"$lookup" : {
@@ -132,6 +153,7 @@ def join(self, compiler, connection):
132
153
{"$eq" : [f"$${ parent_template } { i } " , field ]}
133
154
for i , field in enumerate (rhs_fields )
134
155
]
156
+ + extra_condition
135
157
}
136
158
}
137
159
}
0 commit comments