@@ -17,9 +17,10 @@ class SQLCompiler(compiler.SQLCompiler):
17
17
"""Base class for all Mongo compilers."""
18
18
19
19
query_class = MongoQuery
20
+ _group_pipeline = None
20
21
21
- def pre_sql_setup (self , * args , ** kargs ):
22
- pre_setup = super ().pre_sql_setup (* args , ** kargs )
22
+ def pre_sql_setup (self , with_col_aliases = False ):
23
+ pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
23
24
self .annotations = {}
24
25
group = {}
25
26
group_expressions = set ()
@@ -30,17 +31,20 @@ def pre_sql_setup(self, *args, **kargs):
30
31
else :
31
32
replacements = {}
32
33
for sub_expr in self ._get_aggregate_expressions (expr ):
33
- alias = f"__aggregation{ aggregation_idx } "
34
+ if sub_expr != expr :
35
+ alias = f"__aggregation{ aggregation_idx } "
36
+ aggregation_idx += 1
37
+ else :
38
+ alias = target
39
+ group_expressions |= set (sub_expr .get_group_by_cols ())
34
40
group [alias ] = sub_expr .as_mql (self , self .connection )
35
- aggregation_idx += 1
36
41
column_target = expr .output_field .__class__ ()
37
42
column_target .set_attributes_from_name (alias )
38
43
replacements [sub_expr ] = Col (self .collection_name , column_target )
39
44
result_expr = expr .replace_expressions (replacements )
40
45
41
46
self .annotations [target ] = result_expr
42
47
if group :
43
- """
44
48
order_by = self .get_order_by ()
45
49
for expr , (_ , _ , is_ref ) in order_by :
46
50
# Skip references to the SELECT clause, as all expressions in
@@ -50,7 +54,8 @@ def pre_sql_setup(self, *args, **kargs):
50
54
having_group_by = self .having .get_group_by_cols () if self .having else ()
51
55
for expr in having_group_by :
52
56
group_expressions .add (expr )
53
- """
57
+ if isinstance (self .query .group_by , tuple | list ):
58
+ group_expressions |= set (self .query .group_by )
54
59
55
60
ids = (
56
61
None
@@ -61,7 +66,6 @@ def pre_sql_setup(self, *args, **kargs):
61
66
}
62
67
)
63
68
group ["_id" ] = ids
64
-
65
69
pipeline = [{"$group" : group }]
66
70
if ids :
67
71
pipeline .append (
@@ -79,8 +83,8 @@ def pre_sql_setup(self, *args, **kargs):
79
83
def execute_sql (
80
84
self , result_type = MULTI , chunked_fetch = False , chunk_size = GET_ITERATOR_CHUNK_SIZE
81
85
):
82
- self .pre_sql_setup ()
83
86
# QuerySet.count()
87
+ self .pre_sql_setup ()
84
88
if self .query .annotations == {"__count" : Count ("*" )}:
85
89
return [self .get_count ()]
86
90
@@ -339,17 +343,6 @@ def get_lookup_pipeline(self):
339
343
result += self .query .alias_map [alias ].as_mql (self , self .connection )
340
344
return result
341
345
342
- def _get_aggregate_expressions2 (self , expr ):
343
- stack = [(None , expr )]
344
- while stack :
345
- parent , expr = stack .pop ()
346
- if isinstance (expr , Aggregate ):
347
- yield parent
348
- elif hasattr (expr , "get_source_expressions" ):
349
- stack .extend (
350
- [((expr , idx ), se ) for idx , se in enumerate (expr .get_source_expressions ())]
351
- )
352
-
353
346
def _get_aggregate_expressions (self , expr ):
354
347
stack = [expr ]
355
348
while stack :
@@ -496,4 +489,14 @@ def check_query(self):
496
489
497
490
498
491
class SQLAggregateCompiler (SQLCompiler ):
499
- pass
492
+ def build_query (self , columns = None ):
493
+ query = self .query_class (self )
494
+ query .project_fields = self .get_project_fields (tuple (self .query .annotation_select .items ()))
495
+
496
+ compiler = self .query .inner_query .get_compiler (
497
+ self .using ,
498
+ elide_empty = self .elide_empty ,
499
+ )
500
+ compiler .pre_sql_setup (with_col_aliases = False )
501
+ query .sub_query = compiler .build_query ()
502
+ return query
0 commit comments