@@ -21,9 +21,11 @@ class SQLCompiler(compiler.SQLCompiler):
21
21
def pre_sql_setup (self , with_col_aliases = False ):
22
22
pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
23
23
self .annotations = {}
24
+ # mongo_having = self.having.copy() if self.having else None
24
25
group = {}
25
26
group_expressions = set ()
26
27
aggregation_idx = 1
28
+ all_replacements = {}
27
29
for target , expr in self .query .annotation_select .items ():
28
30
if not expr .contains_aggregate :
29
31
result_expr = expr
@@ -41,7 +43,7 @@ def pre_sql_setup(self, with_col_aliases=False):
41
43
column_target .set_attributes_from_name (alias )
42
44
replacements [sub_expr ] = Col (self .collection_name , column_target )
43
45
result_expr = expr .replace_expressions (replacements )
44
-
46
+ all_replacements . update ( replacements )
45
47
self .annotations [target ] = result_expr
46
48
if group :
47
49
order_by = self .get_order_by ()
@@ -61,7 +63,9 @@ def pre_sql_setup(self, with_col_aliases=False):
61
63
if not group_expressions
62
64
else {
63
65
col .target .column : col .as_mql (self , self .connection )
66
+ # expression aren't needed in the group by clouse ()
64
67
for col in group_expressions
68
+ if isinstance (col , Col )
65
69
}
66
70
)
67
71
group ["_id" ] = ids
@@ -71,7 +75,17 @@ def pre_sql_setup(self, with_col_aliases=False):
71
75
{"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
72
76
)
73
77
if "_id" not in ids :
74
- pipeline .append ({"$unSet" : "$_id" })
78
+ pipeline .append ({"$unset" : "_id" })
79
+ if self .having :
80
+ pipeline .append (
81
+ {
82
+ "$match" : {
83
+ "$expr" : self .having .replace_expressions (all_replacements ).as_mql (
84
+ self , self .connection
85
+ )
86
+ }
87
+ }
88
+ )
75
89
76
90
self ._group_pipeline = pipeline
77
91
else :
@@ -207,7 +221,9 @@ def build_query(self, columns=None):
207
221
query .lookup_pipeline = self .get_lookup_pipeline ()
208
222
query .project_fields = self .get_project_fields (columns )
209
223
try :
210
- query .mongo_query = {"$expr" : self .query .where .as_mql (self , self .connection )}
224
+ query .mongo_query = (
225
+ {"$expr" : self .where .as_mql (self , self .connection )} if self .where else None
226
+ )
211
227
except FullResultSet :
212
228
query .mongo_query = {}
213
229
query .order_by (self ._get_ordering ())
@@ -432,7 +448,8 @@ def execute_update(self, update_spec, **kwargs):
432
448
class SQLAggregateCompiler (SQLCompiler ):
433
449
def build_query (self , columns = None ):
434
450
query = self .query_class (self )
435
- query .project_fields = self .get_project_fields (tuple (self .query .annotation_select .items ()))
451
+ query .project_fields = self .get_project_fields (tuple (self .annotations .items ()))
452
+ query .aggregation_stage = self .get_aggregation_pipeline ()
436
453
437
454
compiler = self .query .inner_query .get_compiler (
438
455
self .using ,
@@ -441,3 +458,6 @@ def build_query(self, columns=None):
441
458
compiler .pre_sql_setup (with_col_aliases = False )
442
459
query .sub_query = compiler .build_query ()
443
460
return query
461
+
462
+ def _make_result (self , result , columns = None ):
463
+ return [result [k ] for k in self .query .annotation_select ]
0 commit comments