@@ -18,12 +18,20 @@ class SQLCompiler(compiler.SQLCompiler):
18
18
19
19
query_class = MongoQuery
20
20
_group_pipeline = None
21
+ aggregation_idx = 0
21
22
22
- def _prepare_expressions_for_pipeline (self , expression , target , aggregation_idx ):
23
+ def _get_colum_from_expression (self , expr , alias ):
24
+ column_target = expr .output_field .__class__ ()
25
+ column_target .db_column = alias
26
+ column_target .set_attributes_from_name (alias )
27
+ return Col (self .collection_name , column_target )
28
+
29
+ def _prepare_expressions_for_pipeline (self , expression , target ):
23
30
replacements = {}
24
31
group = {}
25
32
for sub_expr in self ._get_aggregate_expressions (expression ):
26
- alias = f"__aggregation{ aggregation_idx } " if sub_expr != expression else target
33
+ alias = f"__aggregation{ self .aggregation_idx } " if sub_expr != expression else target
34
+ self .aggregation_idx += 1
27
35
28
36
column_target = sub_expr .output_field .__class__ ()
29
37
column_target .db_column = alias
@@ -39,6 +47,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, aggregation_idx)
39
47
group [alias ] = sub_expr .as_mql (self , self .connection )
40
48
replacing_expr = inner_column
41
49
50
+ sub_expr .as_mql (self , self .connection )
42
51
replacements [sub_expr ] = replacing_expr
43
52
return replacements , group
44
53
@@ -57,19 +66,16 @@ def pre_sql_setup(self, with_col_aliases=False):
57
66
group = {}
58
67
group_expressions = set ()
59
68
all_replacements = {}
60
- for idx , (target , expr ) in enumerate (self .query .annotation_select .items ()):
69
+ self .aggregation_idx = 0
70
+ for target , expr in self .query .annotation_select .items ():
61
71
if expr .contains_aggregate :
62
- replacements , expr_group = self ._prepare_expressions_for_pipeline (expr , target , idx )
63
- result_expr = expr .replace_expressions (replacements )
72
+ replacements , expr_group = self ._prepare_expressions_for_pipeline (expr , target )
64
73
all_replacements .update (replacements )
65
74
group .update (expr_group )
66
- else :
67
- result_expr = expr
68
75
group_expressions |= set (expr .get_group_by_cols ())
69
- self .annotations [target ] = result_expr
70
76
71
77
having_replacements , having_group = self ._prepare_expressions_for_pipeline (
72
- self .having , None , len ( self . query . annotation_select )
78
+ self .having , None
73
79
)
74
80
all_replacements .update (having_replacements )
75
81
group .update (having_group )
@@ -110,7 +116,10 @@ def _ccc(col):
110
116
111
117
if not isinstance (col , Col ):
112
118
annotation_group_idx += 1
113
- return "__annotation_group_1"
119
+ alias = f"__annotation_group_{ annotation_group_idx } "
120
+ col_expr = self ._get_colum_from_expression (col , alias )
121
+ all_replacements [col ] = col_expr
122
+ col = col_expr
114
123
if self .collection_name == col .alias :
115
124
return col .target .column
116
125
return f"{ col .alias } { SEPARATOR } { col .target .column } "
@@ -124,13 +133,17 @@ def _ccc(col):
124
133
for col in group_expressions
125
134
}
126
135
)
136
+ self .annotations = {
137
+ target : expr .replace_expressions (all_replacements )
138
+ for target , expr in self .query .annotation_select .items ()
139
+ }
127
140
pipeline = []
128
141
if not ids :
129
142
group ["_id" ] = None
130
143
pipeline .append ({"$facet" : {"group" : [{"$group" : group }]}})
131
144
pipeline .append (
132
145
{
133
- "$project " : {
146
+ "$addFields " : {
134
147
key : {
135
148
"$getField" : {
136
149
"input" : {"$arrayElemAt" : ["$group" , 0 ]},
@@ -174,6 +187,11 @@ def _ccc(col):
174
187
else :
175
188
self ._group_pipeline = None
176
189
190
+ self .annotations = {
191
+ target : expr .replace_expressions (all_replacements )
192
+ for target , expr in self .query .annotation_select .items ()
193
+ }
194
+
177
195
return pre_setup
178
196
179
197
def execute_sql (
@@ -334,9 +352,8 @@ def build_query(self, columns=None):
334
352
query .order_by (self ._get_ordering ())
335
353
query .project_fields = self .get_project_fields (columns , ordering = query .ordering )
336
354
try :
337
- query .mongo_query = (
338
- {"$expr" : self .where .as_mql (self , self .connection )} if self .where else None
339
- )
355
+ where = getattr (self , "where" , self .query .where )
356
+ query .mongo_query = {"$expr" : where .as_mql (self , self .connection )} if where else None
340
357
except FullResultSet :
341
358
query .mongo_query = {}
342
359
return query
@@ -515,7 +532,7 @@ def insert(self, docs, returning_fields=None):
515
532
class SQLDeleteCompiler (compiler .SQLDeleteCompiler , SQLCompiler ):
516
533
def execute_sql (self , result_type = MULTI ):
517
534
cursor = Cursor ()
518
- cursor .rowcount = self .build_query ([ self . query . get_meta (). pk ] ).delete ()
535
+ cursor .rowcount = self .build_query ().delete ()
519
536
return cursor
520
537
521
538
def check_query (self ):
0 commit comments