6
6
from django .core .exceptions import EmptyResultSet , FullResultSet
7
7
from django .db import NotSupportedError
8
8
from django .db .models .expressions import (
9
+ BaseExpression ,
9
10
Case ,
10
11
Col ,
11
12
ColPairs ,
12
13
CombinedExpression ,
13
14
Exists ,
14
15
ExpressionList ,
15
16
ExpressionWrapper ,
17
+ Func ,
16
18
NegatedExpression ,
17
19
OrderBy ,
18
20
RawSQL ,
23
25
Value ,
24
26
When ,
25
27
)
28
+ from django .db .models .fields .json import KeyTransform
26
29
from django .db .models .sql import Query
27
30
28
- from .. query_utils import process_lhs
31
+ from django_mongodb_backend . fields . array import Array
29
32
33
+ from ..query_utils import is_direct_value , process_lhs
30
34
31
- def case (self , compiler , connection ):
35
+
36
+ def case (self , compiler , connection , as_path = False ):
32
37
case_parts = []
33
38
for case in self .cases :
34
39
case_mql = {}
35
40
try :
36
- case_mql ["case" ] = case .as_mql (compiler , connection )
41
+ case_mql ["case" ] = case .as_mql (compiler , connection , as_path = False )
37
42
except EmptyResultSet :
38
43
continue
39
44
except FullResultSet :
@@ -45,12 +50,16 @@ def case(self, compiler, connection):
45
50
default_mql = self .default .as_mql (compiler , connection )
46
51
if not case_parts :
47
52
return default_mql
48
- return {
53
+ expr = {
49
54
"$switch" : {
50
55
"branches" : case_parts ,
51
56
"default" : default_mql ,
52
57
}
53
58
}
59
+ if as_path :
60
+ return {"$expr" : expr }
61
+
62
+ return expr
54
63
55
64
56
65
def col (self , compiler , connection , as_path = False ): # noqa: ARG001
@@ -76,34 +85,34 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
76
85
return f"{ prefix } { self .target .column } "
77
86
78
87
79
- def col_pairs (self , compiler , connection ):
88
+ def col_pairs (self , compiler , connection , as_path = False ):
80
89
cols = self .get_cols ()
81
90
if len (cols ) > 1 :
82
91
raise NotSupportedError ("ColPairs is not supported." )
83
- return cols [0 ].as_mql (compiler , connection )
92
+ return cols [0 ].as_mql (compiler , connection , as_path = as_path )
84
93
85
94
86
- def combined_expression (self , compiler , connection ):
95
+ def combined_expression (self , compiler , connection , as_path = False ):
87
96
expressions = [
88
- self .lhs .as_mql (compiler , connection ),
89
- self .rhs .as_mql (compiler , connection ),
97
+ self .lhs .as_mql (compiler , connection , as_path = as_path ),
98
+ self .rhs .as_mql (compiler , connection , as_path = as_path ),
90
99
]
91
100
return connection .ops .combine_expression (self .connector , expressions )
92
101
93
102
94
- def expression_wrapper (self , compiler , connection ):
95
- return self .expression .as_mql (compiler , connection )
103
+ def expression_wrapper (self , compiler , connection , as_path = False ):
104
+ return self .expression .as_mql (compiler , connection , as_path = as_path )
96
105
97
106
98
- def negated_expression (self , compiler , connection ):
99
- return {"$not" : expression_wrapper (self , compiler , connection )}
107
+ def negated_expression (self , compiler , connection , as_path = False ):
108
+ return {"$not" : expression_wrapper (self , compiler , connection , as_path = as_path )}
100
109
101
110
102
111
def order_by (self , compiler , connection ):
103
112
return self .expression .as_mql (compiler , connection )
104
113
105
114
106
- def query (self , compiler , connection , get_wrapping_pipeline = None ):
115
+ def query (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
107
116
subquery_compiler = self .get_compiler (connection = connection )
108
117
subquery_compiler .pre_sql_setup (with_col_aliases = False )
109
118
field_name , expr = subquery_compiler .columns [0 ]
@@ -145,14 +154,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None):
145
154
# Erase project_fields since the required value is projected above.
146
155
subquery .project_fields = None
147
156
compiler .subqueries .append (subquery )
157
+ if as_path :
158
+ return f"{ table_output } .{ field_name } "
148
159
return f"${ table_output } .{ field_name } "
149
160
150
161
151
162
def raw_sql (self , compiler , connection ): # noqa: ARG001
152
163
raise NotSupportedError ("RawSQL is not supported on MongoDB." )
153
164
154
165
155
- def ref (self , compiler , connection ): # noqa: ARG001
166
+ def ref (self , compiler , connection , as_path = False ): # noqa: ARG001
156
167
prefix = (
157
168
f"{ self .source .alias } ."
158
169
if isinstance (self .source , Col ) and self .source .alias != compiler .collection_name
@@ -162,32 +173,47 @@ def ref(self, compiler, connection): # noqa: ARG001
162
173
refs , _ = compiler .columns [self .ordinal - 1 ]
163
174
else :
164
175
refs = self .refs
165
- return f"${ prefix } { refs } "
176
+ if not as_path :
177
+ prefix = f"${ prefix } "
178
+ return f"{ prefix } { refs } "
166
179
167
180
168
- def star (self , compiler , connection ): # noqa: ARG001
181
+ def star (self , compiler , connection , ** extra ): # noqa: ARG001
169
182
return {"$literal" : True }
170
183
171
184
172
- def subquery (self , compiler , connection , get_wrapping_pipeline = None ):
173
- return self .query .as_mql (compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline )
185
+ def subquery (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
186
+ expr = self .query .as_mql (
187
+ compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline , as_path = False
188
+ )
189
+ if as_path :
190
+ return {"$expr" : expr }
191
+ return expr
174
192
175
193
176
- def exists (self , compiler , connection , get_wrapping_pipeline = None ):
194
+ def exists (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
177
195
try :
178
- lhs_mql = subquery (self , compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline )
196
+ lhs_mql = subquery (
197
+ self ,
198
+ compiler ,
199
+ connection ,
200
+ get_wrapping_pipeline = get_wrapping_pipeline ,
201
+ as_path = as_path ,
202
+ )
179
203
except EmptyResultSet :
180
204
return Value (False ).as_mql (compiler , connection )
181
- return connection .mongo_operators ["isnull" ](lhs_mql , False )
205
+ if as_path :
206
+ return {"$expr" : connection .mongo_operators_match ["isnull" ](lhs_mql , False )}
207
+ return connection .mongo_operators_expr ["isnull" ](lhs_mql , False )
182
208
183
209
184
- def when (self , compiler , connection ):
185
- return self .condition .as_mql (compiler , connection )
210
+ def when (self , compiler , connection , as_path = False ):
211
+ return self .condition .as_mql (compiler , connection , as_path = as_path )
186
212
187
213
188
- def value (self , compiler , connection ): # noqa: ARG001
214
+ def value (self , compiler , connection , as_path = False ): # noqa: ARG001
189
215
value = self .value
190
- if isinstance (value , (list , int )):
216
+ if isinstance (value , (list , int )) and not as_path :
191
217
# Wrap lists & numbers in $literal to prevent ambiguity when Value
192
218
# appears in $project.
193
219
return {"$literal" : value }
@@ -209,6 +235,36 @@ def value(self, compiler, connection): # noqa: ARG001
209
235
return value
210
236
211
237
238
+ @staticmethod
239
+ def _is_constant_value (value ):
240
+ if isinstance (value , list | Array ):
241
+ iterable = value .get_source_expressions () if isinstance (value , Array ) else value
242
+ return all (_is_constant_value (e ) for e in iterable )
243
+ if is_direct_value (value ):
244
+ return True
245
+ return isinstance (value , Func | Value ) and not (
246
+ value .contains_aggregate
247
+ or value .contains_over_clause
248
+ or value .contains_column_references
249
+ or value .contains_subquery
250
+ )
251
+
252
+
253
+ @staticmethod
254
+ def _is_simple_column (lhs ):
255
+ while isinstance (lhs , KeyTransform ):
256
+ if "." in getattr (lhs , "key_name" , "" ):
257
+ return False
258
+ lhs = lhs .lhs
259
+ col = lhs .source if isinstance (lhs , Ref ) else lhs
260
+ # Foreign columns from parent cannot be addressed as single match
261
+ return isinstance (col , Col ) and col .alias is not None
262
+
263
+
264
+ def _is_simple_expression (self ):
265
+ return self .is_simple_column (self .lhs ) and self .is_constant_value (self .rhs )
266
+
267
+
212
268
def register_expressions ():
213
269
Case .as_mql = case
214
270
Col .as_mql = col
@@ -227,3 +283,6 @@ def register_expressions():
227
283
Subquery .as_mql = subquery
228
284
When .as_mql = when
229
285
Value .as_mql = value
286
+ BaseExpression .is_simple_expression = _is_simple_expression
287
+ BaseExpression .is_simple_column = _is_simple_column
288
+ BaseExpression .is_constant_value = _is_constant_value
0 commit comments