6
6
from django .core .exceptions import EmptyResultSet , FullResultSet
7
7
from django .db import NotSupportedError
8
8
from django .db .models .expressions import (
9
+ BaseExpression ,
9
10
Case ,
10
11
Col ,
11
12
ColPairs ,
12
13
CombinedExpression ,
13
14
Exists ,
14
15
ExpressionList ,
15
16
ExpressionWrapper ,
17
+ Func ,
16
18
NegatedExpression ,
17
19
OrderBy ,
18
20
RawSQL ,
23
25
Value ,
24
26
When ,
25
27
)
28
+ from django .db .models .fields .json import KeyTransform
26
29
from django .db .models .sql import Query
27
30
28
- from django_mongodb_backend .query_utils import process_lhs
31
+ from django_mongodb_backend .fields .array import Array
32
+ from django_mongodb_backend .query_utils import is_direct_value , process_lhs
29
33
30
34
31
- def case (self , compiler , connection ):
35
+ def case (self , compiler , connection , as_path = False ):
32
36
case_parts = []
33
37
for case in self .cases :
34
38
case_mql = {}
35
39
try :
36
- case_mql ["case" ] = case .as_mql (compiler , connection )
40
+ case_mql ["case" ] = case .as_mql (compiler , connection , as_path = False )
37
41
except EmptyResultSet :
38
42
continue
39
43
except FullResultSet :
@@ -45,12 +49,16 @@ def case(self, compiler, connection):
45
49
default_mql = self .default .as_mql (compiler , connection )
46
50
if not case_parts :
47
51
return default_mql
48
- return {
52
+ expr = {
49
53
"$switch" : {
50
54
"branches" : case_parts ,
51
55
"default" : default_mql ,
52
56
}
53
57
}
58
+ if as_path :
59
+ return {"$expr" : expr }
60
+
61
+ return expr
54
62
55
63
56
64
def col (self , compiler , connection , as_path = False ): # noqa: ARG001
@@ -76,34 +84,34 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
76
84
return f"{ prefix } { self .target .column } "
77
85
78
86
79
- def col_pairs (self , compiler , connection ):
87
+ def col_pairs (self , compiler , connection , as_path = False ):
80
88
cols = self .get_cols ()
81
89
if len (cols ) > 1 :
82
90
raise NotSupportedError ("ColPairs is not supported." )
83
- return cols [0 ].as_mql (compiler , connection )
91
+ return cols [0 ].as_mql (compiler , connection , as_path = as_path )
84
92
85
93
86
- def combined_expression (self , compiler , connection ):
94
+ def combined_expression (self , compiler , connection , as_path = False ):
87
95
expressions = [
88
- self .lhs .as_mql (compiler , connection ),
89
- self .rhs .as_mql (compiler , connection ),
96
+ self .lhs .as_mql (compiler , connection , as_path = as_path ),
97
+ self .rhs .as_mql (compiler , connection , as_path = as_path ),
90
98
]
91
99
return connection .ops .combine_expression (self .connector , expressions )
92
100
93
101
94
- def expression_wrapper (self , compiler , connection ):
95
- return self .expression .as_mql (compiler , connection )
102
+ def expression_wrapper (self , compiler , connection , as_path = False ):
103
+ return self .expression .as_mql (compiler , connection , as_path = as_path )
96
104
97
105
98
- def negated_expression (self , compiler , connection ):
99
- return {"$not" : expression_wrapper (self , compiler , connection )}
106
+ def negated_expression (self , compiler , connection , as_path = False ):
107
+ return {"$not" : expression_wrapper (self , compiler , connection , as_path = as_path )}
100
108
101
109
102
110
def order_by (self , compiler , connection ):
103
111
return self .expression .as_mql (compiler , connection )
104
112
105
113
106
- def query (self , compiler , connection , get_wrapping_pipeline = None ):
114
+ def query (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
107
115
subquery_compiler = self .get_compiler (connection = connection )
108
116
subquery_compiler .pre_sql_setup (with_col_aliases = False )
109
117
field_name , expr = subquery_compiler .columns [0 ]
@@ -145,14 +153,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None):
145
153
# Erase project_fields since the required value is projected above.
146
154
subquery .project_fields = None
147
155
compiler .subqueries .append (subquery )
156
+ if as_path :
157
+ return f"{ table_output } .{ field_name } "
148
158
return f"${ table_output } .{ field_name } "
149
159
150
160
151
161
def raw_sql (self , compiler , connection ): # noqa: ARG001
152
162
raise NotSupportedError ("RawSQL is not supported on MongoDB." )
153
163
154
164
155
- def ref (self , compiler , connection ): # noqa: ARG001
165
+ def ref (self , compiler , connection , as_path = False ): # noqa: ARG001
156
166
prefix = (
157
167
f"{ self .source .alias } ."
158
168
if isinstance (self .source , Col ) and self .source .alias != compiler .collection_name
@@ -162,32 +172,47 @@ def ref(self, compiler, connection): # noqa: ARG001
162
172
refs , _ = compiler .columns [self .ordinal - 1 ]
163
173
else :
164
174
refs = self .refs
165
- return f"${ prefix } { refs } "
175
+ if not as_path :
176
+ prefix = f"${ prefix } "
177
+ return f"{ prefix } { refs } "
166
178
167
179
168
- def star (self , compiler , connection ): # noqa: ARG001
180
+ def star (self , compiler , connection , ** extra ): # noqa: ARG001
169
181
return {"$literal" : True }
170
182
171
183
172
- def subquery (self , compiler , connection , get_wrapping_pipeline = None ):
173
- return self .query .as_mql (compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline )
184
+ def subquery (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
185
+ expr = self .query .as_mql (
186
+ compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline , as_path = False
187
+ )
188
+ if as_path :
189
+ return {"$expr" : expr }
190
+ return expr
174
191
175
192
176
- def exists (self , compiler , connection , get_wrapping_pipeline = None ):
193
+ def exists (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
177
194
try :
178
- lhs_mql = subquery (self , compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline )
195
+ lhs_mql = subquery (
196
+ self ,
197
+ compiler ,
198
+ connection ,
199
+ get_wrapping_pipeline = get_wrapping_pipeline ,
200
+ as_path = as_path ,
201
+ )
179
202
except EmptyResultSet :
180
203
return Value (False ).as_mql (compiler , connection )
181
- return connection .mongo_operators ["isnull" ](lhs_mql , False )
204
+ if as_path :
205
+ return {"$expr" : connection .mongo_operators_match ["isnull" ](lhs_mql , False )}
206
+ return connection .mongo_operators_expr ["isnull" ](lhs_mql , False )
182
207
183
208
184
- def when (self , compiler , connection ):
185
- return self .condition .as_mql (compiler , connection )
209
+ def when (self , compiler , connection , as_path = False ):
210
+ return self .condition .as_mql (compiler , connection , as_path = as_path )
186
211
187
212
188
- def value (self , compiler , connection ): # noqa: ARG001
213
+ def value (self , compiler , connection , as_path = False ): # noqa: ARG001
189
214
value = self .value
190
- if isinstance (value , (list , int )):
215
+ if isinstance (value , (list , int )) and not as_path :
191
216
# Wrap lists & numbers in $literal to prevent ambiguity when Value
192
217
# appears in $project.
193
218
return {"$literal" : value }
@@ -209,6 +234,36 @@ def value(self, compiler, connection): # noqa: ARG001
209
234
return value
210
235
211
236
237
+ @staticmethod
238
+ def _is_constant_value (value ):
239
+ if isinstance (value , list | Array ):
240
+ iterable = value .get_source_expressions () if isinstance (value , Array ) else value
241
+ return all (_is_constant_value (e ) for e in iterable )
242
+ if is_direct_value (value ):
243
+ return True
244
+ return isinstance (value , Func | Value ) and not (
245
+ value .contains_aggregate
246
+ or value .contains_over_clause
247
+ or value .contains_column_references
248
+ or value .contains_subquery
249
+ )
250
+
251
+
252
+ @staticmethod
253
+ def _is_simple_column (lhs ):
254
+ while isinstance (lhs , KeyTransform ):
255
+ if "." in getattr (lhs , "key_name" , "" ):
256
+ return False
257
+ lhs = lhs .lhs
258
+ col = lhs .source if isinstance (lhs , Ref ) else lhs
259
+ # Foreign columns from parent cannot be addressed as single match
260
+ return isinstance (col , Col ) and col .alias is not None
261
+
262
+
263
+ def _is_simple_expression (self ):
264
+ return self .is_simple_column (self .lhs ) and self .is_constant_value (self .rhs )
265
+
266
+
212
267
def register_expressions ():
213
268
Case .as_mql = case
214
269
Col .as_mql = col
@@ -227,3 +282,6 @@ def register_expressions():
227
282
Subquery .as_mql = subquery
228
283
When .as_mql = when
229
284
Value .as_mql = value
285
+ BaseExpression .is_simple_expression = _is_simple_expression
286
+ BaseExpression .is_simple_column = _is_simple_column
287
+ BaseExpression .is_constant_value = _is_constant_value
0 commit comments