@@ -17,12 +17,20 @@ class SQLCompiler(compiler.SQLCompiler):
17
17
18
18
query_class = MongoQuery
19
19
_group_pipeline = None
20
+ aggregation_idx = 0
20
21
21
- def _prepare_expressions_for_pipeline (self , expression , target , aggregation_idx ):
22
+ def _get_colum_from_expression (self , expr , alias ):
23
+ column_target = expr .output_field .__class__ ()
24
+ column_target .db_column = alias
25
+ column_target .set_attributes_from_name (alias )
26
+ return Col (self .collection_name , column_target )
27
+
28
+ def _prepare_expressions_for_pipeline (self , expression , target ):
22
29
replacements = {}
23
30
group = {}
24
31
for sub_expr in self ._get_aggregate_expressions (expression ):
25
- alias = f"__aggregation{ aggregation_idx } " if sub_expr != expression else target
32
+ alias = f"__aggregation{ self .aggregation_idx } " if sub_expr != expression else target
33
+ self .aggregation_idx += 1
26
34
27
35
column_target = sub_expr .output_field .__class__ ()
28
36
column_target .db_column = alias
@@ -38,6 +46,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, aggregation_idx)
38
46
group [alias ] = sub_expr .as_mql (self , self .connection )
39
47
replacing_expr = inner_column
40
48
49
+ sub_expr .as_mql (self , self .connection )
41
50
replacements [sub_expr ] = replacing_expr
42
51
return replacements , group
43
52
@@ -56,19 +65,16 @@ def pre_sql_setup(self, with_col_aliases=False):
56
65
group = {}
57
66
group_expressions = set ()
58
67
all_replacements = {}
59
- for idx , (target , expr ) in enumerate (self .query .annotation_select .items ()):
68
+ self .aggregation_idx = 0
69
+ for target , expr in self .query .annotation_select .items ():
60
70
if expr .contains_aggregate :
61
- replacements , expr_group = self ._prepare_expressions_for_pipeline (expr , target , idx )
62
- result_expr = expr .replace_expressions (replacements )
71
+ replacements , expr_group = self ._prepare_expressions_for_pipeline (expr , target )
63
72
all_replacements .update (replacements )
64
73
group .update (expr_group )
65
- else :
66
- result_expr = expr
67
74
group_expressions |= set (expr .get_group_by_cols ())
68
- self .annotations [target ] = result_expr
69
75
70
76
having_replacements , having_group = self ._prepare_expressions_for_pipeline (
71
- self .having , None , len ( self . query . annotation_select )
77
+ self .having , None
72
78
)
73
79
all_replacements .update (having_replacements )
74
80
group .update (having_group )
@@ -109,7 +115,10 @@ def _ccc(col):
109
115
110
116
if not isinstance (col , Col ):
111
117
annotation_group_idx += 1
112
- return "__annotation_group_1"
118
+ alias = f"__annotation_group_{ annotation_group_idx } "
119
+ col_expr = self ._get_colum_from_expression (col , alias )
120
+ all_replacements [col ] = col_expr
121
+ col = col_expr
113
122
if self .collection_name == col .alias :
114
123
return col .target .column
115
124
return f"{ col .alias } { SEPARATOR } { col .target .column } "
@@ -123,13 +132,17 @@ def _ccc(col):
123
132
for col in group_expressions
124
133
}
125
134
)
135
+ self .annotations = {
136
+ target : expr .replace_expressions (all_replacements )
137
+ for target , expr in self .query .annotation_select .items ()
138
+ }
126
139
pipeline = []
127
140
if not ids :
128
141
group ["_id" ] = None
129
142
pipeline .append ({"$facet" : {"group" : [{"$group" : group }]}})
130
143
pipeline .append (
131
144
{
132
- "$project " : {
145
+ "$addFields " : {
133
146
key : {
134
147
"$getField" : {
135
148
"input" : {"$arrayElemAt" : ["$group" , 0 ]},
@@ -173,6 +186,11 @@ def _ccc(col):
173
186
else :
174
187
self ._group_pipeline = None
175
188
189
+ self .annotations = {
190
+ target : expr .replace_expressions (all_replacements )
191
+ for target , expr in self .query .annotation_select .items ()
192
+ }
193
+
176
194
return pre_setup
177
195
178
196
def execute_sql (
@@ -306,9 +324,8 @@ def build_query(self, columns=None):
306
324
query .order_by (self ._get_ordering ())
307
325
query .project_fields = self .get_project_fields (columns , ordering = query .ordering )
308
326
try :
309
- query .mongo_query = (
310
- {"$expr" : self .where .as_mql (self , self .connection )} if self .where else None
311
- )
327
+ where = getattr (self , "where" , self .query .where )
328
+ query .mongo_query = {"$expr" : where .as_mql (self , self .connection )} if where else None
312
329
except FullResultSet :
313
330
query .mongo_query = {}
314
331
return query
@@ -487,7 +504,7 @@ def insert(self, docs, returning_fields=None):
487
504
class SQLDeleteCompiler (compiler .SQLDeleteCompiler , SQLCompiler ):
488
505
def execute_sql (self , result_type = MULTI ):
489
506
cursor = Cursor ()
490
- cursor .rowcount = self .build_query ([ self . query . get_meta (). pk ] ).delete ()
507
+ cursor .rowcount = self .build_query ().delete ()
491
508
return cursor
492
509
493
510
def check_query (self ):
0 commit comments