@@ -16,9 +16,10 @@ class SQLCompiler(compiler.SQLCompiler):
16
16
"""Base class for all Mongo compilers."""
17
17
18
18
query_class = MongoQuery
19
+ _group_pipeline = None
19
20
20
- def pre_sql_setup (self , * args , ** kargs ):
21
- pre_setup = super ().pre_sql_setup (* args , ** kargs )
21
+ def pre_sql_setup (self , with_col_aliases = False ):
22
+ pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
22
23
self .annotations = {}
23
24
group = {}
24
25
group_expressions = set ()
@@ -29,17 +30,20 @@ def pre_sql_setup(self, *args, **kargs):
29
30
else :
30
31
replacements = {}
31
32
for sub_expr in self ._get_aggregate_expressions (expr ):
32
- alias = f"__aggregation{ aggregation_idx } "
33
+ if sub_expr != expr :
34
+ alias = f"__aggregation{ aggregation_idx } "
35
+ aggregation_idx += 1
36
+ else :
37
+ alias = target
38
+ group_expressions |= set (sub_expr .get_group_by_cols ())
33
39
group [alias ] = sub_expr .as_mql (self , self .connection )
34
- aggregation_idx += 1
35
40
column_target = expr .output_field .__class__ ()
36
41
column_target .set_attributes_from_name (alias )
37
42
replacements [sub_expr ] = Col (self .collection_name , column_target )
38
43
result_expr = expr .replace_expressions (replacements )
39
44
40
45
self .annotations [target ] = result_expr
41
46
if group :
42
- """
43
47
order_by = self .get_order_by ()
44
48
for expr , (_ , _ , is_ref ) in order_by :
45
49
# Skip references to the SELECT clause, as all expressions in
@@ -49,7 +53,8 @@ def pre_sql_setup(self, *args, **kargs):
49
53
having_group_by = self .having .get_group_by_cols () if self .having else ()
50
54
for expr in having_group_by :
51
55
group_expressions .add (expr )
52
- """
56
+ if isinstance (self .query .group_by , tuple | list ):
57
+ group_expressions |= set (self .query .group_by )
53
58
54
59
ids = (
55
60
None
@@ -60,7 +65,6 @@ def pre_sql_setup(self, *args, **kargs):
60
65
}
61
66
)
62
67
group ["_id" ] = ids
63
-
64
68
pipeline = [{"$group" : group }]
65
69
if ids :
66
70
pipeline .append (
@@ -78,8 +82,8 @@ def pre_sql_setup(self, *args, **kargs):
78
82
def execute_sql (
79
83
self , result_type = MULTI , chunked_fetch = False , chunk_size = GET_ITERATOR_CHUNK_SIZE
80
84
):
81
- self .pre_sql_setup ()
82
85
# QuerySet.count()
86
+ self .pre_sql_setup ()
83
87
if self .query .annotations == {"__count" : Count ("*" )}:
84
88
return [self .get_count ()]
85
89
@@ -300,17 +304,6 @@ def get_lookup_pipeline(self):
300
304
result += self .query .alias_map [alias ].as_mql (self , self .connection )
301
305
return result
302
306
303
- def _get_aggregate_expressions2 (self , expr ):
304
- stack = [(None , expr )]
305
- while stack :
306
- parent , expr = stack .pop ()
307
- if isinstance (expr , Aggregate ):
308
- yield parent
309
- elif hasattr (expr , "get_source_expressions" ):
310
- stack .extend (
311
- [((expr , idx ), se ) for idx , se in enumerate (expr .get_source_expressions ())]
312
- )
313
-
314
307
def _get_aggregate_expressions (self , expr ):
315
308
stack = [expr ]
316
309
while stack :
@@ -437,4 +430,14 @@ def execute_update(self, update_spec, **kwargs):
437
430
438
431
439
432
class SQLAggregateCompiler (SQLCompiler ):
440
- pass
433
+ def build_query (self , columns = None ):
434
+ query = self .query_class (self )
435
+ query .project_fields = self .get_project_fields (tuple (self .query .annotation_select .items ()))
436
+
437
+ compiler = self .query .inner_query .get_compiler (
438
+ self .using ,
439
+ elide_empty = self .elide_empty ,
440
+ )
441
+ compiler .pre_sql_setup (with_col_aliases = False )
442
+ query .sub_query = compiler .build_query ()
443
+ return query
0 commit comments