4
4
from django .db import DatabaseError , IntegrityError , NotSupportedError
5
5
from django .db .models import Count , Expression
6
6
from django .db .models .aggregates import Aggregate
7
- from django .db .models .expressions import OrderBy , Value
7
+ from django .db .models .expressions import Col , OrderBy , Value
8
8
from django .db .models .sql import compiler
9
9
from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , ORDER_DIR , SINGLE
10
10
from django .utils .functional import cached_property
@@ -18,6 +18,62 @@ class SQLCompiler(compiler.SQLCompiler):
18
18
19
19
query_class = MongoQuery
20
20
21
+ def pre_sql_setup (self ):
22
+ super ().pre_sql_setup ()
23
+ self .annotations = {}
24
+ group = {}
25
+ group_expressions = set ()
26
+ aggregation_idx = 1
27
+ for target , expr in self .query .annotation_select .items ():
28
+ if not expr .contains_aggregate :
29
+ result_expr = expr
30
+ else :
31
+ replacements = {}
32
+ for sub_expr in self ._get_aggregate_expressions (expr ):
33
+ alias = f"__aggregation{ aggregation_idx } "
34
+ group [alias ] = sub_expr .as_mql (self , self .connection )
35
+ aggregation_idx += 1
36
+ column_target = expr .output_field .__class__ ()
37
+ column_target .set_attributes_from_name (alias )
38
+ replacements [sub_expr ] = Col (self .collection_name , column_target )
39
+ result_expr = expr .replace_expressions (replacements )
40
+
41
+ self .annotations [target ] = result_expr
42
+ if group :
43
+ """
44
+ order_by = self.get_order_by()
45
+ for expr, (_, _, is_ref) in order_by:
46
+ # Skip references to the SELECT clause, as all expressions in
47
+ # the SELECT clause are already part of the GROUP BY.
48
+ if not is_ref:
49
+ group_expressions |= set(expr.get_group_by_cols())
50
+ having_group_by = self.having.get_group_by_cols() if self.having else ()
51
+ for expr in having_group_by:
52
+ group_expressions.add(expr)
53
+ """
54
+
55
+ ids = (
56
+ None
57
+ if not group_expressions
58
+ else {
59
+ col .target .column : col .as_mql (self , self .connection )
60
+ for col in group_expressions
61
+ }
62
+ )
63
+ group ["_id" ] = ids
64
+
65
+ pipeline = [{"$group" : group }]
66
+ if ids :
67
+ pipeline .append (
68
+ {"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
69
+ )
70
+ if "_id" not in ids :
71
+ pipeline .append ({"$unSet" : "$_id" })
72
+
73
+ self ._group_pipeline = pipeline
74
+ else :
75
+ self ._group_pipeline = None
76
+
21
77
def execute_sql (
22
78
self , result_type = MULTI , chunked_fetch = False , chunk_size = GET_ITERATOR_CHUNK_SIZE
23
79
):
@@ -85,7 +141,7 @@ def results_iter(
85
141
return rows
86
142
87
143
def has_results (self ):
88
- return bool (self .get_count ( check_exists = True ))
144
+ return bool (self .execute_sql ( SINGLE ))
89
145
90
146
def _make_result (self , entity , columns ):
91
147
"""
@@ -172,9 +228,9 @@ def build_query(self, columns=None):
172
228
"""Check if the query is supported and prepare a MongoQuery."""
173
229
self .check_query ()
174
230
query = self .query_class (self )
175
- query .project_fields = self .get_project_fields (columns )
176
- query .lookup_pipeline = self .get_lookup_pipeline ()
177
231
query .aggregation_stage = self .get_aggregation_pipeline ()
232
+ query .lookup_pipeline = self .get_lookup_pipeline ()
233
+ query .project_fields = self .get_project_fields (columns )
178
234
try :
179
235
query .mongo_query = {"$expr" : self .query .where .as_mql (self , self .connection )}
180
236
except FullResultSet :
@@ -214,7 +270,7 @@ def project_field(column):
214
270
215
271
return (
216
272
tuple (map (project_field , columns ))
217
- + tuple (self .query . annotation_select .items ())
273
+ + tuple (self .annotations .items ())
218
274
+ tuple (map (project_field , related_columns ))
219
275
)
220
276
@@ -279,52 +335,34 @@ def get_lookup_pipeline(self):
279
335
result += self .query .alias_map [alias ].as_mql (self , self .connection )
280
336
return result
281
337
282
- def get_aggregation_pipeline (self ):
283
- pipeline = None
284
- if any (isinstance (a , Aggregate ) for a in self .query .annotations .values ()):
285
- result = {}
286
- # self.get_group_by(self.select, [])
287
- for alias , annotation in self .query .annotation_select .items ():
288
- value = annotation .as_mql (self , self .connection )
289
- if isinstance (value , list ):
290
- value = value [0 ]
291
- result [alias ] = value
292
-
293
- expressions = set ()
294
- for expr , * _ in self .select :
295
- expressions |= set (expr .get_group_by_cols ())
296
- order_by = self .get_order_by ()
297
- for expr , (_ , _ , is_ref ) in order_by :
298
- # Skip references to the SELECT clause, as all expressions in
299
- # the SELECT clause are already part of the GROUP BY.
300
- if not is_ref :
301
- expressions |= set (expr .get_group_by_cols ())
302
- having_group_by = self .having .get_group_by_cols () if self .having else ()
303
- for expr in having_group_by :
304
- expressions .add (expr )
305
-
306
- ids = (
307
- None
308
- if not expressions
309
- else {col .target .column : col .as_mql (self , self .connection ) for col in expressions }
310
- )
311
- result ["_id" ] = ids
312
-
313
- pipeline = [{"$group" : result }]
314
- if ids :
315
- pipeline .append (
316
- {"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
338
+ def _get_aggregate_expressions2 (self , expr ):
339
+ stack = [(None , expr )]
340
+ while stack :
341
+ parent , expr = stack .pop ()
342
+ if isinstance (expr , Aggregate ):
343
+ yield parent
344
+ elif hasattr (expr , "get_source_expressions" ):
345
+ stack .extend (
346
+ [((expr , idx ), se ) for idx , se in enumerate (expr .get_source_expressions ())]
317
347
)
318
- if "_id" not in ids :
319
- pipeline .append ({"$unSet" : "$_id" })
320
348
321
- return pipeline
349
+ def _get_aggregate_expressions (self , expr ):
350
+ stack = [expr ]
351
+ while stack :
352
+ expr = stack .pop ()
353
+ if isinstance (expr , Aggregate ):
354
+ yield expr
355
+ elif hasattr (expr , "get_source_expressions" ):
356
+ stack .extend (expr .get_source_expressions ())
357
+
358
+ def get_aggregation_pipeline (self ):
359
+ return self ._group_pipeline
322
360
323
361
def get_project_fields (self , columns = None ):
324
362
fields = {}
325
363
for name , expr in columns or []:
326
364
try :
327
- column = name if isinstance ( expr , Aggregate ) else expr .target .column
365
+ column = expr .target .column
328
366
except AttributeError :
329
367
# Generate the MQL for an annotation.
330
368
try :
0 commit comments