@@ -17,9 +17,10 @@ class SQLCompiler(compiler.SQLCompiler):
17
17
"""Base class for all Mongo compilers."""
18
18
19
19
query_class = MongoQuery
20
+ _group_pipeline = None
20
21
21
- def pre_sql_setup (self , * args , ** kargs ):
22
- pre_setup = super ().pre_sql_setup (* args , ** kargs )
22
+ def pre_sql_setup (self , with_col_aliases = False ):
23
+ pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
23
24
self .annotations = {}
24
25
group = {}
25
26
group_expressions = set ()
@@ -30,17 +31,20 @@ def pre_sql_setup(self, *args, **kargs):
30
31
else :
31
32
replacements = {}
32
33
for sub_expr in self ._get_aggregate_expressions (expr ):
33
- alias = f"__aggregation{ aggregation_idx } "
34
+ if sub_expr != expr :
35
+ alias = f"__aggregation{ aggregation_idx } "
36
+ aggregation_idx += 1
37
+ else :
38
+ alias = target
39
+ group_expressions |= set (sub_expr .get_group_by_cols ())
34
40
group [alias ] = sub_expr .as_mql (self , self .connection )
35
- aggregation_idx += 1
36
41
column_target = expr .output_field .__class__ ()
37
42
column_target .set_attributes_from_name (alias )
38
43
replacements [sub_expr ] = Col (self .collection_name , column_target )
39
44
result_expr = expr .replace_expressions (replacements )
40
45
41
46
self .annotations [target ] = result_expr
42
47
if group :
43
- """
44
48
order_by = self .get_order_by ()
45
49
for expr , (_ , _ , is_ref ) in order_by :
46
50
# Skip references to the SELECT clause, as all expressions in
@@ -50,7 +54,8 @@ def pre_sql_setup(self, *args, **kargs):
50
54
having_group_by = self .having .get_group_by_cols () if self .having else ()
51
55
for expr in having_group_by :
52
56
group_expressions .add (expr )
53
- """
57
+ if isinstance (self .query .group_by , tuple | list ):
58
+ group_expressions |= set (self .query .group_by )
54
59
55
60
ids = (
56
61
None
@@ -61,7 +66,6 @@ def pre_sql_setup(self, *args, **kargs):
61
66
}
62
67
)
63
68
group ["_id" ] = ids
64
-
65
69
pipeline = [{"$group" : group }]
66
70
if ids :
67
71
pipeline .append (
@@ -79,8 +83,8 @@ def pre_sql_setup(self, *args, **kargs):
79
83
def execute_sql (
80
84
self , result_type = MULTI , chunked_fetch = False , chunk_size = GET_ITERATOR_CHUNK_SIZE
81
85
):
82
- self .pre_sql_setup ()
83
86
# QuerySet.count()
87
+ self .pre_sql_setup ()
84
88
if self .query .annotations == {"__count" : Count ("*" )}:
85
89
return [self .get_count ()]
86
90
@@ -337,17 +341,6 @@ def get_lookup_pipeline(self):
337
341
result += self .query .alias_map [alias ].as_mql (self , self .connection )
338
342
return result
339
343
340
- def _get_aggregate_expressions2 (self , expr ):
341
- stack = [(None , expr )]
342
- while stack :
343
- parent , expr = stack .pop ()
344
- if isinstance (expr , Aggregate ):
345
- yield parent
346
- elif hasattr (expr , "get_source_expressions" ):
347
- stack .extend (
348
- [((expr , idx ), se ) for idx , se in enumerate (expr .get_source_expressions ())]
349
- )
350
-
351
344
def _get_aggregate_expressions (self , expr ):
352
345
stack = [expr ]
353
346
while stack :
@@ -494,4 +487,14 @@ def check_query(self):
494
487
495
488
496
489
class SQLAggregateCompiler (SQLCompiler ):
497
- pass
490
+ def build_query (self , columns = None ):
491
+ query = self .query_class (self )
492
+ query .project_fields = self .get_project_fields (tuple (self .query .annotation_select .items ()))
493
+
494
+ compiler = self .query .inner_query .get_compiler (
495
+ self .using ,
496
+ elide_empty = self .elide_empty ,
497
+ )
498
+ compiler .pre_sql_setup (with_col_aliases = False )
499
+ query .sub_query = compiler .build_query ()
500
+ return query
0 commit comments