1
+ import itertools
2
+
1
3
from django .core .exceptions import EmptyResultSet , FullResultSet
2
4
from django .db import DatabaseError , IntegrityError , NotSupportedError
3
5
from django .db .models import Count , Expression
@@ -17,22 +19,33 @@ class SQLCompiler(compiler.SQLCompiler):
17
19
"""Base class for all Mongo compilers."""
18
20
19
21
query_class = MongoQuery
20
- _group_pipeline = None
21
- aggregation_idx = 0
22
-
23
- def _get_colum_from_expression (self , expr , alias ):
22
+ SEPARATOR = "10__MESSI__3"
23
+
24
+ def _get_group_alias_column (self , col , annotation_group_idx ):
25
+ """Generate alias and replacement for group columns."""
26
+ replacement = None
27
+ if not isinstance (col , Col ):
28
+ alias = f"__annotation_group{ next (annotation_group_idx )} "
29
+ col_expr = self ._get_column_from_expression (col , alias )
30
+ replacement = col_expr
31
+ col = col_expr
32
+ if self .collection_name == col .alias :
33
+ return col .target .column , replacement
34
+ return f"{ col .alias } { self .SEPARATOR } { col .target .column } " , replacement
35
+
36
+ def _get_column_from_expression (self , expr , alias ):
37
+ """Get column target from expression."""
24
38
column_target = expr .output_field .__class__ ()
25
39
column_target .db_column = alias
26
40
column_target .set_attributes_from_name (alias )
27
41
return Col (self .collection_name , column_target )
28
42
29
- def _prepare_expressions_for_pipeline (self , expression , target ):
43
+ def _prepare_expressions_for_pipeline (self , expression , target , count ):
44
+ """Prepare expressions for the MongoDB aggregation pipeline."""
30
45
replacements = {}
31
46
group = {}
32
47
for sub_expr in self ._get_aggregate_expressions (expression ):
33
- alias = f"__aggregation{ self .aggregation_idx } " if sub_expr != expression else target
34
- self .aggregation_idx += 1
35
-
48
+ alias = f"__aggregation{ next (count )} " if sub_expr != expression else target
36
49
column_target = sub_expr .output_field .__class__ ()
37
50
column_target .db_column = alias
38
51
column_target .set_attributes_from_name (alias )
@@ -55,127 +68,109 @@ def _prepare_expressions_for_pipeline(self, expression, target):
55
68
replacements [sub_expr ] = replacing_expr
56
69
return replacements , group
57
70
58
- @staticmethod
59
- def _random_separtor ():
60
- import random
61
- import string
62
-
63
- size = 6
64
- chars = string .ascii_uppercase + string .digits
65
- return "" .join (random .choice (chars ) for _ in range (size )) # noqa: S311
66
-
67
- def pre_sql_setup (self , with_col_aliases = False ):
68
- pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
69
- self .annotations = {}
71
+ def _prepare_annotations_for_group_pipeline (self ):
72
+ """Prepare annotations for the MongoDB aggregation pipeline."""
73
+ replacements = {}
70
74
group = {}
71
- group_expressions = set ()
72
- all_replacements = {}
73
- self .aggregation_idx = 0
75
+ count = itertools .count (start = 1 )
74
76
for target , expr in self .query .annotation_select .items ():
75
77
if expr .contains_aggregate :
76
- replacements , expr_group = self ._prepare_expressions_for_pipeline (expr , target )
77
- all_replacements .update (replacements )
78
+ new_replacements , expr_group = self ._prepare_expressions_for_pipeline (
79
+ expr , target , count
80
+ )
81
+ replacements .update (new_replacements )
78
82
group .update (expr_group )
79
- group_expressions |= set (expr .get_group_by_cols ())
80
83
81
84
having_replacements , having_group = self ._prepare_expressions_for_pipeline (
82
- self .having , None
85
+ self .having , None , count
83
86
)
84
- all_replacements .update (having_replacements )
87
+ replacements .update (having_replacements )
85
88
group .update (having_group )
89
+ return group , replacements
86
90
87
- if group or self .query .group_by :
88
- order_by = self .get_order_by ()
89
- for expr , (_ , _ , is_ref ) in order_by :
90
- # Skip references to the SELECT clause, as all expressions in
91
- # the SELECT clause are already part of the GROUP BY.
92
- if not is_ref :
93
- group_expressions |= set (expr .get_group_by_cols ())
94
-
95
- for expr , * _ in self .select :
91
+ def _get_group_id_expressions (self ):
92
+ """Generate group ID expressions for the aggregation pipeline."""
93
+ group_expressions = set ()
94
+ replacements = {}
95
+ order_by = self .get_order_by ()
96
+ for expr , (_ , _ , is_ref ) in order_by :
97
+ if not is_ref :
96
98
group_expressions |= set (expr .get_group_by_cols ())
97
99
98
- having_group_by = self .having .get_group_by_cols () if self .having else ()
99
- for expr in having_group_by :
100
- group_expressions .add (expr )
101
- if isinstance (self .query .group_by , tuple | list ):
102
- group_expressions |= set (self .query .group_by )
103
- elif self .query .group_by is None :
104
- group_expressions = set ()
100
+ for expr , * _ in self .select :
101
+ group_expressions |= set (expr .get_group_by_cols ())
105
102
106
- all_strings = "" .join (
107
- str (col .as_mql (self , self .connection )) for col in group_expressions
108
- )
103
+ having_group_by = self .having .get_group_by_cols () if self .having else ()
104
+ for expr in having_group_by :
105
+ group_expressions .add (expr )
106
+ if isinstance (self .query .group_by , tuple | list ):
107
+ group_expressions |= set (self .query .group_by )
108
+ elif self .query .group_by is None :
109
+ group_expressions = set ()
109
110
110
- while True :
111
- random_string = self ._random_separtor ()
112
- if random_string not in all_strings :
113
- break
114
- SEPARATOR = f"__{ random_string } __"
115
-
116
- annotation_group_idx = 0
117
-
118
- def _ccc (col ):
119
- nonlocal annotation_group_idx
120
-
121
- if not isinstance (col , Col ):
122
- annotation_group_idx += 1
123
- alias = f"__annotation_group_{ annotation_group_idx } "
124
- col_expr = self ._get_colum_from_expression (col , alias )
125
- all_replacements [col ] = col_expr
126
- col = col_expr
127
- if self .collection_name == col .alias :
128
- return col .target .column
129
- return f"{ col .alias } { SEPARATOR } { col .target .column } "
130
-
131
- ids = (
132
- None
133
- if not group_expressions
134
- else {
135
- _ccc (col ): col .as_mql (self , self .connection )
136
- # expression aren't needed in the group by clouse ()
137
- for col in group_expressions
138
- }
139
- )
140
- self .annotations = {
141
- target : expr .replace_expressions (all_replacements )
142
- for target , expr in self .query .annotation_select .items ()
143
- }
144
- pipeline = []
145
- if not ids :
146
- group ["_id" ] = None
147
- pipeline .append ({"$facet" : {"group" : [{"$group" : group }]}})
148
- pipeline .append (
149
- {
150
- "$addFields" : {
151
- key : {
152
- "$getField" : {
153
- "input" : {"$arrayElemAt" : ["$group" , 0 ]},
154
- "field" : key ,
155
- }
111
+ if not group_expressions :
112
+ ids = None
113
+ else :
114
+ annotation_group_idx = itertools .count (start = 1 )
115
+ ids = {}
116
+ for col in group_expressions :
117
+ alias , replacement = self ._get_group_alias_column (col , annotation_group_idx )
118
+ ids [alias ] = col .as_mql (self , self .connection )
119
+ if replacement is not None :
120
+ replacements [col ] = replacement
121
+
122
+ return ids , replacements
123
+
124
+ def _build_group_pipeline (self , ids , group ):
125
+ """Build the aggregation pipeline for grouping."""
126
+ pipeline = []
127
+ if not ids :
128
+ group ["_id" ] = None
129
+ pipeline .append ({"$facet" : {"group" : [{"$group" : group }]}})
130
+ pipeline .append (
131
+ {
132
+ "$addFields" : {
133
+ key : {
134
+ "$getField" : {
135
+ "input" : {"$arrayElemAt" : ["$group" , 0 ]},
136
+ "field" : key ,
156
137
}
157
- for key in group
158
138
}
139
+ for key in group
159
140
}
160
- )
161
- else :
162
- group ["_id" ] = ids
163
- pipeline .append ({"$group" : group })
164
- sets = {}
165
- for key in ids :
166
- value = f"$_id.{ key } "
167
- if SEPARATOR in key :
168
- subtable , field = key .split (SEPARATOR )
169
- if subtable not in sets :
170
- sets [subtable ] = {}
171
- sets [subtable ][field ] = value
172
- else :
173
- sets [key ] = value
174
-
175
- pipeline .append ({"$addFields" : sets })
176
- if "_id" not in sets :
177
- pipeline .append ({"$unset" : "_id" })
141
+ }
142
+ )
143
+ else :
144
+ group ["_id" ] = ids
145
+ pipeline .append ({"$group" : group })
146
+ sets = {}
147
+ for key in ids :
148
+ value = f"$_id.{ key } "
149
+ if self .SEPARATOR in key :
150
+ subtable , field = key .split (self .SEPARATOR )
151
+ if subtable not in sets :
152
+ sets [subtable ] = {}
153
+ sets [subtable ][field ] = value
154
+ else :
155
+ sets [key ] = value
156
+
157
+ pipeline .append ({"$addFields" : sets })
158
+ if "_id" not in sets :
159
+ pipeline .append ({"$unset" : "_id" })
178
160
161
+ return pipeline
162
+
163
+ def pre_sql_setup (self , with_col_aliases = False ):
164
+ pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
165
+ group , all_replacements = self ._prepare_annotations_for_group_pipeline ()
166
+
167
+ # The query.group_by is either None (no GROUP BY at all), True
168
+ # (group by select fields), or a list of expressions to be added
169
+ # to the group by.
170
+ if group or self .query .group_by :
171
+ ids , replacements = self ._get_group_id_expressions ()
172
+ all_replacements .update (replacements )
173
+ pipeline = self ._build_group_pipeline (ids , group )
179
174
if self .having :
180
175
pipeline .append (
181
176
{
@@ -186,7 +181,6 @@ def _ccc(col):
186
181
}
187
182
}
188
183
)
189
-
190
184
self ._group_pipeline = pipeline
191
185
else :
192
186
self ._group_pipeline = None
@@ -201,7 +195,6 @@ def _ccc(col):
201
195
def execute_sql (
202
196
self , result_type = MULTI , chunked_fetch = False , chunk_size = GET_ITERATOR_CHUNK_SIZE
203
197
):
204
- # QuerySet.count()
205
198
self .pre_sql_setup ()
206
199
columns = self .get_columns ()
207
200
try :
@@ -291,34 +284,6 @@ def check_query(self):
291
284
if any (key .startswith ("_prefetch_related_" ) for key in self .query .extra ):
292
285
raise NotSupportedError ("QuerySet.prefetch_related() is not supported on MongoDB." )
293
286
raise NotSupportedError ("QuerySet.extra() is not supported on MongoDB." )
294
- if any (
295
- isinstance (a , Aggregate ) and not isinstance (a , Count )
296
- for a in self .query .annotations .values ()
297
- ):
298
- # raise NotSupportedError("QuerySet.aggregate() isn't supported on MongoDB.")
299
- pass
300
-
301
- def get_count (self , check_exists = False ):
302
- """
303
- Count objects matching the current filters / constraints.
304
-
305
- If `check_exists` is True, only check if any object matches.
306
- """
307
- kwargs = {}
308
- # If this query is sliced, the limits will be set on the subquery.
309
- inner_query = getattr (self .query , "inner_query" , None )
310
- low_mark = inner_query .low_mark if inner_query else 0
311
- high_mark = inner_query .high_mark if inner_query else None
312
- if low_mark > 0 :
313
- kwargs ["skip" ] = low_mark
314
- if check_exists :
315
- kwargs ["limit" ] = 1
316
- elif high_mark is not None :
317
- kwargs ["limit" ] = high_mark - low_mark
318
- try :
319
- return self .build_query ().count (** kwargs )
320
- except EmptyResultSet :
321
- return 0
322
287
323
288
def build_query (self , columns = None ):
324
289
"""Check if the query is supported and prepare a MongoQuery."""
@@ -511,6 +476,7 @@ def insert(self, docs, returning_fields=None):
511
476
512
477
class SQLDeleteCompiler (compiler .SQLDeleteCompiler , SQLCompiler ):
513
478
def execute_sql (self , result_type = MULTI ):
479
+ self .pre_sql_setup ()
514
480
cursor = Cursor ()
515
481
cursor .rowcount = self .build_query ().delete ()
516
482
return cursor
0 commit comments