2
2
from django .db import DatabaseError , IntegrityError , NotSupportedError
3
3
from django .db .models import Count , Expression
4
4
from django .db .models .aggregates import Aggregate
5
- from django .db .models .expressions import OrderBy , Value
5
+ from django .db .models .expressions import Col , OrderBy , Value
6
6
from django .db .models .sql import compiler
7
- from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , ORDER_DIR
7
+ from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , ORDER_DIR , SINGLE
8
+
8
9
from django .utils .functional import cached_property
9
10
10
11
from .base import Cursor
@@ -16,6 +17,62 @@ class SQLCompiler(compiler.SQLCompiler):
16
17
17
18
query_class = MongoQuery
18
19
20
+ def pre_sql_setup (self ):
21
+ super ().pre_sql_setup ()
22
+ self .annotations = {}
23
+ group = {}
24
+ group_expressions = set ()
25
+ aggregation_idx = 1
26
+ for target , expr in self .query .annotation_select .items ():
27
+ if not expr .contains_aggregate :
28
+ result_expr = expr
29
+ else :
30
+ replacements = {}
31
+ for sub_expr in self ._get_aggregate_expressions (expr ):
32
+ alias = f"__aggregation{ aggregation_idx } "
33
+ group [alias ] = sub_expr .as_mql (self , self .connection )
34
+ aggregation_idx += 1
35
+ column_target = expr .output_field .__class__ ()
36
+ column_target .set_attributes_from_name (alias )
37
+ replacements [sub_expr ] = Col (self .collection_name , column_target )
38
+ result_expr = expr .replace_expressions (replacements )
39
+
40
+ self .annotations [target ] = result_expr
41
+ if group :
42
+ """
43
+ order_by = self.get_order_by()
44
+ for expr, (_, _, is_ref) in order_by:
45
+ # Skip references to the SELECT clause, as all expressions in
46
+ # the SELECT clause are already part of the GROUP BY.
47
+ if not is_ref:
48
+ group_expressions |= set(expr.get_group_by_cols())
49
+ having_group_by = self.having.get_group_by_cols() if self.having else ()
50
+ for expr in having_group_by:
51
+ group_expressions.add(expr)
52
+ """
53
+
54
+ ids = (
55
+ None
56
+ if not group_expressions
57
+ else {
58
+ col .target .column : col .as_mql (self , self .connection )
59
+ for col in group_expressions
60
+ }
61
+ )
62
+ group ["_id" ] = ids
63
+
64
+ pipeline = [{"$group" : group }]
65
+ if ids :
66
+ pipeline .append (
67
+ {"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
68
+ )
69
+ if "_id" not in ids :
70
+ pipeline .append ({"$unSet" : "$_id" })
71
+
72
+ self ._group_pipeline = pipeline
73
+ else :
74
+ self ._group_pipeline = None
75
+
19
76
def execute_sql (
20
77
self , result_type = MULTI , chunked_fetch = False , chunk_size = GET_ITERATOR_CHUNK_SIZE
21
78
):
@@ -33,11 +90,13 @@ def execute_sql(
33
90
except EmptyResultSet :
34
91
return iter ([]) if result_type == MULTI else None
35
92
36
- return (
37
- (self ._make_result (row , columns ) for row in query .fetch ())
38
- if result_type == MULTI
39
- else self ._make_result (next (query .fetch ()), columns )
40
- )
93
+ if result_type == MULTI :
94
+ return (self ._make_result (row , columns ) for row in query .fetch ())
95
+
96
+ try :
97
+ return self ._make_result (next (query .fetch ()), columns )
98
+ except StopIteration :
99
+ return None
41
100
42
101
def results_iter (
43
102
self ,
@@ -64,7 +123,7 @@ def results_iter(
64
123
return rows
65
124
66
125
def has_results (self ):
67
- return bool (self .get_count ( check_exists = True ))
126
+ return bool (self .execute_sql ( SINGLE ))
68
127
69
128
def _make_result (self , entity , columns ):
70
129
"""
@@ -143,9 +202,9 @@ def build_query(self, columns=None):
143
202
"""Check if the query is supported and prepare a MongoQuery."""
144
203
self .check_query ()
145
204
query = self .query_class (self )
146
- query .project_fields = self .get_project_fields (columns )
147
- query .lookup_pipeline = self .get_lookup_pipeline ()
148
205
query .aggregation_stage = self .get_aggregation_pipeline ()
206
+ query .lookup_pipeline = self .get_lookup_pipeline ()
207
+ query .project_fields = self .get_project_fields (columns )
149
208
try :
150
209
query .mongo_query = {"$expr" : self .query .where .as_mql (self , self .connection )}
151
210
except FullResultSet :
@@ -185,7 +244,7 @@ def project_field(column):
185
244
186
245
return (
187
246
tuple (map (project_field , columns ))
188
- + tuple (self .query . annotation_select .items ())
247
+ + tuple (self .annotations .items ())
189
248
+ tuple (map (project_field , related_columns ))
190
249
)
191
250
@@ -250,52 +309,34 @@ def get_lookup_pipeline(self):
250
309
result += self .query .alias_map [alias ].as_mql (self , self .connection )
251
310
return result
252
311
253
- def get_aggregation_pipeline (self ):
254
- pipeline = None
255
- if any (isinstance (a , Aggregate ) for a in self .query .annotations .values ()):
256
- result = {}
257
- # self.get_group_by(self.select, [])
258
- for alias , annotation in self .query .annotation_select .items ():
259
- value = annotation .as_mql (self , self .connection )
260
- if isinstance (value , list ):
261
- value = value [0 ]
262
- result [alias ] = value
263
-
264
- expressions = set ()
265
- for expr , * _ in self .select :
266
- expressions |= set (expr .get_group_by_cols ())
267
- order_by = self .get_order_by ()
268
- for expr , (_ , _ , is_ref ) in order_by :
269
- # Skip references to the SELECT clause, as all expressions in
270
- # the SELECT clause are already part of the GROUP BY.
271
- if not is_ref :
272
- expressions |= set (expr .get_group_by_cols ())
273
- having_group_by = self .having .get_group_by_cols () if self .having else ()
274
- for expr in having_group_by :
275
- expressions .add (expr )
276
-
277
- ids = (
278
- None
279
- if not expressions
280
- else {col .target .column : col .as_mql (self , self .connection ) for col in expressions }
281
- )
282
- result ["_id" ] = ids
283
-
284
- pipeline = [{"$group" : result }]
285
- if ids :
286
- pipeline .append (
287
- {"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
312
+ def _get_aggregate_expressions2 (self , expr ):
313
+ stack = [(None , expr )]
314
+ while stack :
315
+ parent , expr = stack .pop ()
316
+ if isinstance (expr , Aggregate ):
317
+ yield parent
318
+ elif hasattr (expr , "get_source_expressions" ):
319
+ stack .extend (
320
+ [((expr , idx ), se ) for idx , se in enumerate (expr .get_source_expressions ())]
288
321
)
289
- if "_id" not in ids :
290
- pipeline .append ({"$unSet" : "$_id" })
291
322
292
- return pipeline
323
+ def _get_aggregate_expressions (self , expr ):
324
+ stack = [expr ]
325
+ while stack :
326
+ expr = stack .pop ()
327
+ if isinstance (expr , Aggregate ):
328
+ yield expr
329
+ elif hasattr (expr , "get_source_expressions" ):
330
+ stack .extend (expr .get_source_expressions ())
331
+
332
+ def get_aggregation_pipeline (self ):
333
+ return self ._group_pipeline
293
334
294
335
def get_project_fields (self , columns = None ):
295
336
fields = {}
296
337
for name , expr in columns or []:
297
338
try :
298
- column = name if isinstance ( expr , Aggregate ) else expr .target .column
339
+ column = expr .target .column
299
340
except AttributeError :
300
341
# Generate the MQL for an annotation.
301
342
try :
0 commit comments