@@ -19,6 +19,29 @@ class SQLCompiler(compiler.SQLCompiler):
19
19
query_class = MongoQuery
20
20
_group_pipeline = None
21
21
22
+ def _prepare_expressions_for_pipeline (self , expression , target , aggregation_idx ):
23
+ replacements = {}
24
+ group = {}
25
+ for sub_expr in self ._get_aggregate_expressions (expression ):
26
+ alias = f"__aggregation{ aggregation_idx } " if sub_expr != expression else target
27
+
28
+ column_target = sub_expr .output_field .__class__ ()
29
+ column_target .db_column = alias
30
+ column_target .set_attributes_from_name (alias )
31
+ inner_column = Col (self .collection_name , column_target )
32
+ if sub_expr .distinct :
33
+ inner_expr = sub_expr .as_mql (self , self .connection , force_filters = True )
34
+ rhs = next (iter (inner_expr .values ()))
35
+ group [alias ] = {"$addToSet" : rhs }
36
+ replacing_expr = sub_expr .copy ()
37
+ replacing_expr .set_source_expressions ([inner_column ])
38
+ else :
39
+ group [alias ] = sub_expr .as_mql (self , self .connection )
40
+ replacing_expr = inner_column
41
+
42
+ replacements [sub_expr ] = replacing_expr
43
+ return replacements , group
44
+
22
45
@staticmethod
23
46
def _random_separtor ():
24
47
import random
@@ -33,39 +56,24 @@ def pre_sql_setup(self, with_col_aliases=False):
33
56
self .annotations = {}
34
57
group = {}
35
58
group_expressions = set ()
36
- aggregation_idx = 1
37
59
all_replacements = {}
38
- for target , expr in self .query .annotation_select .items ():
39
- if not expr .contains_aggregate :
40
- result_expr = expr
41
- else :
42
- replacements = {}
43
- for sub_expr in self ._get_aggregate_expressions (expr ):
44
- if sub_expr != expr :
45
- alias = f"__aggregation{ aggregation_idx } "
46
- aggregation_idx += 1
47
- else :
48
- alias = target
49
-
50
- column_target = expr .output_field .__class__ ()
51
- column_target .db_column = alias
52
- column_target .set_attributes_from_name (alias )
53
- inner_column = Col (self .collection_name , column_target )
54
- if sub_expr .distinct :
55
- inner_expr = sub_expr .as_mql (self , self .connection , force_filters = True )
56
- rhs = next (iter (inner_expr .values ()))
57
- group [alias ] = {"$addToSet" : rhs }
58
- replacing_expr = sub_expr .copy ()
59
- replacing_expr .set_source_expressions ([inner_column ])
60
- else :
61
- group [alias ] = sub_expr .as_mql (self , self .connection )
62
- replacing_expr = inner_column
63
-
64
- replacements [sub_expr ] = replacing_expr
60
+ for idx , (target , expr ) in enumerate (self .query .annotation_select .items ()):
61
+ if expr .contains_aggregate :
62
+ replacements , expr_group = self ._prepare_expressions_for_pipeline (expr , target , idx )
65
63
result_expr = expr .replace_expressions (replacements )
66
64
all_replacements .update (replacements )
65
+ group .update (expr_group )
66
+ else :
67
+ result_expr = expr
67
68
group_expressions |= set (expr .get_group_by_cols ())
68
69
self .annotations [target ] = result_expr
70
+
71
+ having_replacements , having_group = self ._prepare_expressions_for_pipeline (
72
+ self .having , None , len (self .query .annotation_select )
73
+ )
74
+ all_replacements .update (having_replacements )
75
+ group .update (having_group )
76
+
69
77
if group or self .query .group_by :
70
78
order_by = self .get_order_by ()
71
79
for expr , (_ , _ , is_ref ) in order_by :
@@ -95,7 +103,14 @@ def pre_sql_setup(self, with_col_aliases=False):
95
103
break
96
104
SEPARATOR = f"__{ random_string } __"
97
105
106
+ annotation_group_idx = 0
107
+
98
108
def _ccc (col ):
109
+ nonlocal annotation_group_idx
110
+
111
+ if not isinstance (col , Col ):
112
+ annotation_group_idx += 1
113
+ return "__annotation_group_1"
99
114
if self .collection_name == col .alias :
100
115
return col .target .column
101
116
return f"{ col .alias } { SEPARATOR } { col .target .column } "
@@ -107,7 +122,6 @@ def _ccc(col):
107
122
_ccc (col ): col .as_mql (self , self .connection )
108
123
# expression aren't needed in the group by clouse ()
109
124
for col in group_expressions
110
- if isinstance (col , Col )
111
125
}
112
126
)
113
127
pipeline = []
@@ -141,10 +155,7 @@ def _ccc(col):
141
155
else :
142
156
sets [key ] = value
143
157
144
- pipeline .append (
145
- # {"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
146
- {"$addFields" : sets }
147
- )
158
+ pipeline .append ({"$addFields" : sets })
148
159
if "_id" not in sets :
149
160
pipeline .append ({"$unset" : "_id" })
150
161
0 commit comments