@@ -22,9 +22,11 @@ class SQLCompiler(compiler.SQLCompiler):
22
22
def pre_sql_setup (self , with_col_aliases = False ):
23
23
pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
24
24
self .annotations = {}
25
+ # mongo_having = self.having.copy() if self.having else None
25
26
group = {}
26
27
group_expressions = set ()
27
28
aggregation_idx = 1
29
+ all_replacements = {}
28
30
for target , expr in self .query .annotation_select .items ():
29
31
if not expr .contains_aggregate :
30
32
result_expr = expr
@@ -42,7 +44,7 @@ def pre_sql_setup(self, with_col_aliases=False):
42
44
column_target .set_attributes_from_name (alias )
43
45
replacements [sub_expr ] = Col (self .collection_name , column_target )
44
46
result_expr = expr .replace_expressions (replacements )
45
-
47
+ all_replacements . update ( replacements )
46
48
self .annotations [target ] = result_expr
47
49
if group :
48
50
order_by = self .get_order_by ()
@@ -62,7 +64,9 @@ def pre_sql_setup(self, with_col_aliases=False):
62
64
if not group_expressions
63
65
else {
64
66
col .target .column : col .as_mql (self , self .connection )
67
+ # expression aren't needed in the group by clouse ()
65
68
for col in group_expressions
69
+ if isinstance (col , Col )
66
70
}
67
71
)
68
72
group ["_id" ] = ids
@@ -72,7 +76,17 @@ def pre_sql_setup(self, with_col_aliases=False):
72
76
{"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
73
77
)
74
78
if "_id" not in ids :
75
- pipeline .append ({"$unSet" : "$_id" })
79
+ pipeline .append ({"$unset" : "_id" })
80
+ if self .having :
81
+ pipeline .append (
82
+ {
83
+ "$match" : {
84
+ "$expr" : self .having .replace_expressions (all_replacements ).as_mql (
85
+ self , self .connection
86
+ )
87
+ }
88
+ }
89
+ )
76
90
77
91
self ._group_pipeline = pipeline
78
92
else :
@@ -238,7 +252,9 @@ def build_query(self, columns=None):
238
252
query .lookup_pipeline = self .get_lookup_pipeline ()
239
253
query .project_fields = self .get_project_fields (columns )
240
254
try :
241
- query .mongo_query = {"$expr" : self .query .where .as_mql (self , self .connection )}
255
+ query .mongo_query = (
256
+ {"$expr" : self .where .as_mql (self , self .connection )} if self .where else None
257
+ )
242
258
except FullResultSet :
243
259
query .mongo_query = {}
244
260
query .order_by (self ._get_ordering ())
@@ -489,7 +505,8 @@ def check_query(self):
489
505
class SQLAggregateCompiler (SQLCompiler ):
490
506
def build_query (self , columns = None ):
491
507
query = self .query_class (self )
492
- query .project_fields = self .get_project_fields (tuple (self .query .annotation_select .items ()))
508
+ query .project_fields = self .get_project_fields (tuple (self .annotations .items ()))
509
+ query .aggregation_stage = self .get_aggregation_pipeline ()
493
510
494
511
compiler = self .query .inner_query .get_compiler (
495
512
self .using ,
@@ -498,3 +515,6 @@ def build_query(self, columns=None):
498
515
compiler .pre_sql_setup (with_col_aliases = False )
499
516
query .sub_query = compiler .build_query ()
500
517
return query
518
+
519
+ def _make_result (self , result , columns = None ):
520
+ return [result [k ] for k in self .query .annotation_select ]
0 commit comments