1
- from copy import deepcopy
2
-
3
1
from django .db import NotSupportedError
4
- from django .db .models .aggregates import Aggregate , Count , StdDev , Variance
5
- from django .db .models .expressions import Case , Func , Star , Value , When
6
- from django .db .models .functions import Now
2
+ from django .db .models .expressions import Func
7
3
from django .db .models .functions .comparison import Cast , Coalesce , Greatest , Least , NullIf
8
4
from django .db .models .functions .datetime import (
9
5
Extract ,
17
13
ExtractWeek ,
18
14
ExtractWeekDay ,
19
15
ExtractYear ,
16
+ Now ,
20
17
TruncBase ,
21
18
)
22
19
from django .db .models .functions .math import Ceil , Cot , Degrees , Log , Power , Radians , Random , Round
34
31
Trim ,
35
32
Upper ,
36
33
)
37
- from django .db .models .lookups import Exact
38
- from django .db .models .sql .where import WhereNode
39
34
40
35
from .query_utils import process_lhs
41
36
42
- MONGO_AGGREGATIONS = {
43
- Count : "sum" ,
44
- StdDev : "stdDev" , # Samp or Pop suffix added in aggregate().
45
- Variance : "stdDev" , # Likewise.
46
- }
47
37
MONGO_OPERATORS = {
48
38
Ceil : "ceil" ,
49
39
Coalesce : "ifNull" ,
68
58
}
69
59
70
60
71
- def aggregate (self , compiler , connection , ** extra_context ): # noqa: ARG001
72
- if self .filter :
73
- node = self .copy ()
74
- node .filter = None
75
- source_expressions = node .get_source_expressions ()
76
- condition = When (self .filter , then = source_expressions [0 ])
77
- node .set_source_expressions ([Case (condition )] + source_expressions [1 :])
78
- else :
79
- node = self
80
- lhs_mql = process_lhs (node , compiler , connection )
81
- operator = MONGO_AGGREGATIONS .get (self .__class__ , self .function .lower ())
82
- # Add suffixes to StdDev/Variance.
83
- if self .function .endswith ("_SAMP" ):
84
- operator += "Samp"
85
- elif self .function .endswith ("_POP" ):
86
- operator += "Pop"
87
- return {f"${ operator } " : lhs_mql }
88
-
89
-
90
61
def cast (self , compiler , connection ):
91
62
output_type = connection .data_types [self .output_field .get_internal_type ()]
92
63
lhs_mql = process_lhs (self , compiler , connection )[0 ]
@@ -117,42 +88,6 @@ def cot(self, compiler, connection):
117
88
return {"$divide" : [1 , {"$tan" : lhs_mql }]}
118
89
119
90
120
- def count (self , compiler , connection , resolve_inner_expression = False , ** extra_context ): # noqa: ARG001
121
- """
122
- When resolve_inner_expression is True, return the argument as MQL that
123
- resolves as a value. This is used to count different elements, so the inner
124
- values are returned to be pushed into a set.
125
- """
126
- if not self .distinct or resolve_inner_expression :
127
- if self .filter :
128
- node = self .copy ()
129
- node .filter = None
130
- source_expressions = node .get_source_expressions ()
131
- filter_ = deepcopy (self .filter )
132
- filter_ .add (
133
- WhereNode ([Exact (source_expressions [0 ], Value (None ))], negated = True ),
134
- filter_ .default ,
135
- )
136
- condition = When (filter_ , then = Value (1 ))
137
- node .set_source_expressions ([Case (condition )] + source_expressions [1 :])
138
- inner_expression = process_lhs (node , compiler , connection )
139
- else :
140
- lhs_mql = process_lhs (self , compiler , connection )
141
- null_cond = {"$in" : [{"$type" : lhs_mql }, ["missing" , "null" ]]}
142
- inner_expression = {
143
- "$cond" : {"if" : null_cond , "then" : None , "else" : lhs_mql if self .distinct else 1 }
144
- }
145
- if resolve_inner_expression :
146
- return inner_expression
147
- return {"$sum" : inner_expression }
148
- # If distinct=True or resolve_inner_expression=False, sum the size
149
- # of the set.
150
- lhs_mql = process_lhs (self , compiler , connection )
151
- # Subtract 1 if None is in the set (it shouldn't have been counted).
152
- exits_null = {"$cond" : {"if" : {"$in" : [{"$literal" : None }, lhs_mql ]}, "then" : - 1 , "else" : 0 }}
153
- return {"$add" : [{"$size" : lhs_mql }, exits_null ]}
154
-
155
-
156
91
def extract (self , compiler , connection ):
157
92
lhs_mql = process_lhs (self , compiler , connection )
158
93
operator = EXTRACT_OPERATORS .get (self .lookup_name )
@@ -223,10 +158,6 @@ def round_(self, compiler, connection):
223
158
return {"$round" : [expr .as_mql (compiler , connection ) for expr in self .get_source_expressions ()]}
224
159
225
160
226
- def star (self , compiler , connection ): # noqa: ARG001
227
- return {"$literal" : True }
228
-
229
-
230
161
def str_index (self , compiler , connection ):
231
162
lhs = process_lhs (self , compiler , connection )
232
163
# StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB.
@@ -261,12 +192,10 @@ def trunc(self, compiler, connection):
261
192
262
193
263
194
def register_functions ():
264
- Aggregate .as_mql = aggregate
265
195
Cast .as_mql = cast
266
196
Concat .as_mql = concat
267
197
ConcatPair .as_mql = concat_pair
268
198
Cot .as_mql = cot
269
- Count .as_mql = count
270
199
Extract .as_mql = extract
271
200
Func .as_mql = func
272
201
Left .as_mql = left
@@ -279,7 +208,6 @@ def register_functions():
279
208
Replace .as_mql = replace
280
209
Round .as_mql = round_
281
210
RTrim .as_mql = trim ("rtrim" )
282
- Star .as_mql = star
283
211
StrIndex .as_mql = str_index
284
212
Substr .as_mql = substr
285
213
Trim .as_mql = trim ("trim" )
0 commit comments