@@ -18,6 +18,29 @@ class SQLCompiler(compiler.SQLCompiler):
18
18
query_class = MongoQuery
19
19
_group_pipeline = None
20
20
21
+ def _prepare_expressions_for_pipeline (self , expression , target , aggregation_idx ):
22
+ replacements = {}
23
+ group = {}
24
+ for sub_expr in self ._get_aggregate_expressions (expression ):
25
+ alias = f"__aggregation{ aggregation_idx } " if sub_expr != expression else target
26
+
27
+ column_target = sub_expr .output_field .__class__ ()
28
+ column_target .db_column = alias
29
+ column_target .set_attributes_from_name (alias )
30
+ inner_column = Col (self .collection_name , column_target )
31
+ if sub_expr .distinct :
32
+ inner_expr = sub_expr .as_mql (self , self .connection , force_filters = True )
33
+ rhs = next (iter (inner_expr .values ()))
34
+ group [alias ] = {"$addToSet" : rhs }
35
+ replacing_expr = sub_expr .copy ()
36
+ replacing_expr .set_source_expressions ([inner_column ])
37
+ else :
38
+ group [alias ] = sub_expr .as_mql (self , self .connection )
39
+ replacing_expr = inner_column
40
+
41
+ replacements [sub_expr ] = replacing_expr
42
+ return replacements , group
43
+
21
44
@staticmethod
22
45
def _random_separtor ():
23
46
import random
@@ -32,39 +55,24 @@ def pre_sql_setup(self, with_col_aliases=False):
32
55
self .annotations = {}
33
56
group = {}
34
57
group_expressions = set ()
35
- aggregation_idx = 1
36
58
all_replacements = {}
37
- for target , expr in self .query .annotation_select .items ():
38
- if not expr .contains_aggregate :
39
- result_expr = expr
40
- else :
41
- replacements = {}
42
- for sub_expr in self ._get_aggregate_expressions (expr ):
43
- if sub_expr != expr :
44
- alias = f"__aggregation{ aggregation_idx } "
45
- aggregation_idx += 1
46
- else :
47
- alias = target
48
-
49
- column_target = expr .output_field .__class__ ()
50
- column_target .db_column = alias
51
- column_target .set_attributes_from_name (alias )
52
- inner_column = Col (self .collection_name , column_target )
53
- if sub_expr .distinct :
54
- inner_expr = sub_expr .as_mql (self , self .connection , force_filters = True )
55
- rhs = next (iter (inner_expr .values ()))
56
- group [alias ] = {"$addToSet" : rhs }
57
- replacing_expr = sub_expr .copy ()
58
- replacing_expr .set_source_expressions ([inner_column ])
59
- else :
60
- group [alias ] = sub_expr .as_mql (self , self .connection )
61
- replacing_expr = inner_column
62
-
63
- replacements [sub_expr ] = replacing_expr
59
+ for idx , (target , expr ) in enumerate (self .query .annotation_select .items ()):
60
+ if expr .contains_aggregate :
61
+ replacements , expr_group = self ._prepare_expressions_for_pipeline (expr , target , idx )
64
62
result_expr = expr .replace_expressions (replacements )
65
63
all_replacements .update (replacements )
64
+ group .update (expr_group )
65
+ else :
66
+ result_expr = expr
66
67
group_expressions |= set (expr .get_group_by_cols ())
67
68
self .annotations [target ] = result_expr
69
+
70
+ having_replacements , having_group = self ._prepare_expressions_for_pipeline (
71
+ self .having , None , len (self .query .annotation_select )
72
+ )
73
+ all_replacements .update (having_replacements )
74
+ group .update (having_group )
75
+
68
76
if group or self .query .group_by :
69
77
order_by = self .get_order_by ()
70
78
for expr , (_ , _ , is_ref ) in order_by :
@@ -94,7 +102,14 @@ def pre_sql_setup(self, with_col_aliases=False):
94
102
break
95
103
SEPARATOR = f"__{ random_string } __"
96
104
105
+ annotation_group_idx = 0
106
+
97
107
def _ccc (col ):
108
+ nonlocal annotation_group_idx
109
+
110
+ if not isinstance (col , Col ):
111
+ annotation_group_idx += 1
112
+ return "__annotation_group_1"
98
113
if self .collection_name == col .alias :
99
114
return col .target .column
100
115
return f"{ col .alias } { SEPARATOR } { col .target .column } "
@@ -106,7 +121,6 @@ def _ccc(col):
106
121
_ccc (col ): col .as_mql (self , self .connection )
107
122
# expression aren't needed in the group by clouse ()
108
123
for col in group_expressions
109
- if isinstance (col , Col )
110
124
}
111
125
)
112
126
pipeline = []
@@ -140,10 +154,7 @@ def _ccc(col):
140
154
else :
141
155
sets [key ] = value
142
156
143
- pipeline .append (
144
- # {"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
145
- {"$addFields" : sets }
146
- )
157
+ pipeline .append ({"$addFields" : sets })
147
158
if "_id" not in sets :
148
159
pipeline .append ({"$unset" : "_id" })
149
160
0 commit comments