@@ -21,9 +21,11 @@ class SQLCompiler(compiler.SQLCompiler):
21
21
def pre_sql_setup (self , with_col_aliases = False ):
22
22
pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
23
23
self .annotations = {}
24
+ # mongo_having = self.having.copy() if self.having else None
24
25
group = {}
25
26
group_expressions = set ()
26
27
aggregation_idx = 1
28
+ all_replacements = {}
27
29
for target , expr in self .query .annotation_select .items ():
28
30
if not expr .contains_aggregate :
29
31
result_expr = expr
@@ -41,7 +43,7 @@ def pre_sql_setup(self, with_col_aliases=False):
41
43
column_target .set_attributes_from_name (alias )
42
44
replacements [sub_expr ] = Col (self .collection_name , column_target )
43
45
result_expr = expr .replace_expressions (replacements )
44
-
46
+ all_replacements . update ( replacements )
45
47
self .annotations [target ] = result_expr
46
48
if group :
47
49
order_by = self .get_order_by ()
@@ -61,7 +63,9 @@ def pre_sql_setup(self, with_col_aliases=False):
61
63
if not group_expressions
62
64
else {
63
65
col .target .column : col .as_mql (self , self .connection )
66
+ # expression aren't needed in the group by clouse ()
64
67
for col in group_expressions
68
+ if isinstance (col , Col )
65
69
}
66
70
)
67
71
group ["_id" ] = ids
@@ -71,7 +75,17 @@ def pre_sql_setup(self, with_col_aliases=False):
71
75
{"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
72
76
)
73
77
if "_id" not in ids :
74
- pipeline .append ({"$unSet" : "$_id" })
78
+ pipeline .append ({"$unset" : "_id" })
79
+ if self .having :
80
+ pipeline .append (
81
+ {
82
+ "$match" : {
83
+ "$expr" : self .having .replace_expressions (all_replacements ).as_mql (
84
+ self , self .connection
85
+ )
86
+ }
87
+ }
88
+ )
75
89
76
90
self ._group_pipeline = pipeline
77
91
else :
@@ -206,7 +220,9 @@ def build_query(self, columns=None):
206
220
query .lookup_pipeline = self .get_lookup_pipeline ()
207
221
query .project_fields = self .get_project_fields (columns )
208
222
try :
209
- query .mongo_query = {"$expr" : self .query .where .as_mql (self , self .connection )}
223
+ query .mongo_query = (
224
+ {"$expr" : self .where .as_mql (self , self .connection )} if self .where else None
225
+ )
210
226
except FullResultSet :
211
227
query .mongo_query = {}
212
228
query .order_by (self ._get_ordering ())
@@ -431,7 +447,8 @@ def execute_update(self, update_spec, **kwargs):
431
447
class SQLAggregateCompiler (SQLCompiler ):
432
448
def build_query (self , columns = None ):
433
449
query = self .query_class (self )
434
- query .project_fields = self .get_project_fields (tuple (self .query .annotation_select .items ()))
450
+ query .project_fields = self .get_project_fields (tuple (self .annotations .items ()))
451
+ query .aggregation_stage = self .get_aggregation_pipeline ()
435
452
436
453
compiler = self .query .inner_query .get_compiler (
437
454
self .using ,
@@ -440,3 +457,6 @@ def build_query(self, columns=None):
440
457
compiler .pre_sql_setup (with_col_aliases = False )
441
458
query .sub_query = compiler .build_query ()
442
459
return query
460
+
461
+ def _make_result (self , result , columns = None ):
462
+ return [result [k ] for k in self .query .annotation_select ]
0 commit comments