-from itertools import chain
+import itertools
 
 from django.core.exceptions import EmptyResultSet, FullResultSet
 from django.db import DatabaseError, IntegrityError, NotSupportedError
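The module-level import replaces the bare chain import because the compiler now also uses itertools.count to hand out unique aliases (see count = itertools.count(start=1) and next(count) in the hunks below). A minimal, illustrative sketch of that aliasing pattern, not taken verbatim from the patch:

    import itertools

    # One shared counter threaded through the helpers replaces the old mutable
    # class-level aggregation_idx, so every sub-aggregation gets a distinct alias.
    count = itertools.count(start=1)
    aliases = [f"__aggregation{next(count)}" for _ in range(3)]
    print(aliases)  # ['__aggregation1', '__aggregation2', '__aggregation3']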
@@ -19,22 +19,33 @@ class SQLCompiler(compiler.SQLCompiler):
     """Base class for all Mongo compilers."""
 
     query_class = MongoQuery
-    _group_pipeline = None
-    aggregation_idx = 0
-
-    def _get_colum_from_expression(self, expr, alias):
+    SEPARATOR = "10__MESSI__3"
+
+    def _get_group_alias_column(self, col, annotation_group_idx):
+        """Generate alias and replacement for group columns."""
+        replacement = None
+        if not isinstance(col, Col):
+            alias = f"__annotation_group{next(annotation_group_idx)}"
+            col_expr = self._get_column_from_expression(col, alias)
+            replacement = col_expr
+            col = col_expr
+        if self.collection_name == col.alias:
+            return col.target.column, replacement
+        return f"{col.alias}{self.SEPARATOR}{col.target.column}", replacement
+
+    def _get_column_from_expression(self, expr, alias):
+        """Get column target from expression."""
         column_target = expr.output_field.__class__()
         column_target.db_column = alias
         column_target.set_attributes_from_name(alias)
         return Col(self.collection_name, column_target)
 
-    def _prepare_expressions_for_pipeline(self, expression, target):
+    def _prepare_expressions_for_pipeline(self, expression, target, count):
+        """Prepare expressions for the MongoDB aggregation pipeline."""
         replacements = {}
         group = {}
         for sub_expr in self._get_aggregate_expressions(expression):
-            alias = f"__aggregation{self.aggregation_idx}" if sub_expr != expression else target
-            self.aggregation_idx += 1
-
+            alias = f"__aggregation{next(count)}" if sub_expr != expression else target
             column_target = sub_expr.output_field.__class__()
             column_target.db_column = alias
             column_target.set_attributes_from_name(alias)
@@ -57,127 +68,109 @@ def _prepare_expressions_for_pipeline(self, expression, target):
             replacements[sub_expr] = replacing_expr
         return replacements, group
 
-    @staticmethod
-    def _random_separtor():
-        import random
-        import string
-
-        size = 6
-        chars = string.ascii_uppercase + string.digits
-        return "".join(random.choice(chars) for _ in range(size))  # noqa: S311
-
-    def pre_sql_setup(self, with_col_aliases=False):
-        pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
-        self.annotations = {}
+    def _prepare_annotations_for_group_pipeline(self):
+        """Prepare annotations for the MongoDB aggregation pipeline."""
+        replacements = {}
         group = {}
-        group_expressions = set()
-        all_replacements = {}
-        self.aggregation_idx = 0
+        count = itertools.count(start=1)
         for target, expr in self.query.annotation_select.items():
             if expr.contains_aggregate:
-                replacements, expr_group = self._prepare_expressions_for_pipeline(expr, target)
-                all_replacements.update(replacements)
+                new_replacements, expr_group = self._prepare_expressions_for_pipeline(
+                    expr, target, count
+                )
+                replacements.update(new_replacements)
                 group.update(expr_group)
-                group_expressions |= set(expr.get_group_by_cols())
 
         having_replacements, having_group = self._prepare_expressions_for_pipeline(
-            self.having, None
+            self.having, None, count
         )
-        all_replacements.update(having_replacements)
+        replacements.update(having_replacements)
         group.update(having_group)
+        return group, replacements
 
-        if group or self.query.group_by:
-            order_by = self.get_order_by()
-            for expr, (_, _, is_ref) in order_by:
-                # Skip references to the SELECT clause, as all expressions in
-                # the SELECT clause are already part of the GROUP BY.
-                if not is_ref:
-                    group_expressions |= set(expr.get_group_by_cols())
-
-            for expr, *_ in self.select:
+    def _get_group_id_expressions(self):
+        """Generate group ID expressions for the aggregation pipeline."""
+        group_expressions = set()
+        replacements = {}
+        order_by = self.get_order_by()
+        for expr, (_, _, is_ref) in order_by:
+            if not is_ref:
                 group_expressions |= set(expr.get_group_by_cols())
 
-            having_group_by = self.having.get_group_by_cols() if self.having else ()
-            for expr in having_group_by:
-                group_expressions.add(expr)
-            if isinstance(self.query.group_by, tuple | list):
-                group_expressions |= set(self.query.group_by)
-            elif self.query.group_by is None:
-                group_expressions = set()
+        for expr, *_ in self.select:
+            group_expressions |= set(expr.get_group_by_cols())
 
-            all_strings = "".join(
-                str(col.as_mql(self, self.connection)) for col in group_expressions
-            )
+        having_group_by = self.having.get_group_by_cols() if self.having else ()
+        for expr in having_group_by:
+            group_expressions.add(expr)
+        if isinstance(self.query.group_by, tuple | list):
+            group_expressions |= set(self.query.group_by)
+        elif self.query.group_by is None:
+            group_expressions = set()
 
-            while True:
-                random_string = self._random_separtor()
-                if random_string not in all_strings:
-                    break
-            SEPARATOR = f"__{random_string}__"
-
-            annotation_group_idx = 0
-
-            def _ccc(col):
-                nonlocal annotation_group_idx
-
-                if not isinstance(col, Col):
-                    annotation_group_idx += 1
-                    alias = f"__annotation_group_{annotation_group_idx}"
-                    col_expr = self._get_colum_from_expression(col, alias)
-                    all_replacements[col] = col_expr
-                    col = col_expr
-                if self.collection_name == col.alias:
-                    return col.target.column
-                return f"{col.alias}{SEPARATOR}{col.target.column}"
-
-            ids = (
-                None
-                if not group_expressions
-                else {
-                    _ccc(col): col.as_mql(self, self.connection)
-                    # expression aren't needed in the group by clouse ()
-                    for col in group_expressions
-                }
-            )
-            self.annotations = {
-                target: expr.replace_expressions(all_replacements)
-                for target, expr in self.query.annotation_select.items()
-            }
-            pipeline = []
-            if not ids:
-                group["_id"] = None
-                pipeline.append({"$facet": {"group": [{"$group": group}]}})
-                pipeline.append(
-                    {
-                        "$addFields": {
-                            key: {
-                                "$getField": {
-                                    "input": {"$arrayElemAt": ["$group", 0]},
-                                    "field": key,
-                                }
+        if not group_expressions:
+            ids = None
+        else:
+            annotation_group_idx = itertools.count(start=1)
+            ids = {}
+            for col in group_expressions:
+                alias, replacement = self._get_group_alias_column(col, annotation_group_idx)
+                ids[alias] = col.as_mql(self, self.connection)
+                if replacement is not None:
+                    replacements[col] = replacement
+
+        return ids, replacements
+
+    def _build_group_pipeline(self, ids, group):
+        """Build the aggregation pipeline for grouping."""
+        pipeline = []
+        if not ids:
+            group["_id"] = None
+            pipeline.append({"$facet": {"group": [{"$group": group}]}})
+            pipeline.append(
+                {
+                    "$addFields": {
+                        key: {
+                            "$getField": {
+                                "input": {"$arrayElemAt": ["$group", 0]},
+                                "field": key,
                             }
-                            for key in group
                         }
+                        for key in group
                     }
-                )
-            else:
-                group["_id"] = ids
-                pipeline.append({"$group": group})
-                sets = {}
-                for key in ids:
-                    value = f"$_id.{key}"
-                    if SEPARATOR in key:
-                        subtable, field = key.split(SEPARATOR)
-                        if subtable not in sets:
-                            sets[subtable] = {}
-                        sets[subtable][field] = value
-                    else:
-                        sets[key] = value
-
-                pipeline.append({"$addFields": sets})
-                if "_id" not in sets:
-                    pipeline.append({"$unset": "_id"})
+                }
+            )
+        else:
+            group["_id"] = ids
+            pipeline.append({"$group": group})
+            sets = {}
+            for key in ids:
+                value = f"$_id.{key}"
+                if self.SEPARATOR in key:
+                    subtable, field = key.split(self.SEPARATOR)
+                    if subtable not in sets:
+                        sets[subtable] = {}
+                    sets[subtable][field] = value
+                else:
+                    sets[key] = value
+
+            pipeline.append({"$addFields": sets})
+            if "_id" not in sets:
+                pipeline.append({"$unset": "_id"})
 
+        return pipeline
+
+    def pre_sql_setup(self, with_col_aliases=False):
+        pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
+        group, all_replacements = self._prepare_annotations_for_group_pipeline()
+
+        # The query.group_by is either None (no GROUP BY at all), True
+        # (group by select fields), or a list of expressions to be added
+        # to the group by.
+        if group or self.query.group_by:
+            ids, replacements = self._get_group_id_expressions()
+            all_replacements.update(replacements)
+            pipeline = self._build_group_pipeline(ids, group)
             if self.having:
                 pipeline.append(
                     {
@@ -188,7 +181,6 @@ def _ccc(col):
                         }
                     }
                 )
-
             self._group_pipeline = pipeline
         else:
             self._group_pipeline = None
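For orientation, here is a rough sketch of the stages _build_group_pipeline() emits when there are grouping keys. The model, field names, and the "$sum" accumulator for the annotation are illustrative assumptions; only the stage layout ($group keyed by the ids dict, $addFields restoring the keys from _id, and the trailing $unset) mirrors the code above:

    # Hypothetical query: Book.objects.values("author_id").annotate(total=Count("id"))
    ids = {"author_id": "$author_id"}          # one entry per grouping column
    group = {"total": {"$sum": 1}, "_id": ids}  # accumulator shape is assumed
    pipeline = [
        {"$group": group},
        # keys without SEPARATOR are copied straight back to the top level...
        {"$addFields": {"author_id": "$_id.author_id"}},
        # ...and _id is dropped because it is not among the restored fields
        {"$unset": "_id"},
    ]

Keys that do contain SEPARATOR (columns referenced through another table alias) are instead split back into a nested sub-document by the same $addFields stage.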
@@ -203,7 +195,6 @@ def _ccc(col):
     def execute_sql(
         self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
     ):
-        # QuerySet.count()
         self.pre_sql_setup()
         columns = self.get_columns()
         try:
@@ -256,7 +247,7 @@ def results_iter(
 
         fields = [s[0] for s in self.select[0 : self.col_count]]
         converters = self.get_converters(fields)
-        rows = chain.from_iterable(results)
+        rows = itertools.chain.from_iterable(results)
         if converters:
             rows = self.apply_converters(rows, converters)
         if tuple_expected:
@@ -320,34 +311,6 @@ def check_query(self):
             if any(key.startswith("_prefetch_related_") for key in self.query.extra):
                 raise NotSupportedError("QuerySet.prefetch_related() is not supported on MongoDB.")
             raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")
-        if any(
-            isinstance(a, Aggregate) and not isinstance(a, Count)
-            for a in self.query.annotations.values()
-        ):
-            # raise NotSupportedError("QuerySet.aggregate() isn't supported on MongoDB.")
-            pass
-
-    def get_count(self, check_exists=False):
-        """
-        Count objects matching the current filters / constraints.
-
-        If `check_exists` is True, only check if any object matches.
-        """
-        kwargs = {}
-        # If this query is sliced, the limits will be set on the subquery.
-        inner_query = getattr(self.query, "inner_query", None)
-        low_mark = inner_query.low_mark if inner_query else 0
-        high_mark = inner_query.high_mark if inner_query else None
-        if low_mark > 0:
-            kwargs["skip"] = low_mark
-        if check_exists:
-            kwargs["limit"] = 1
-        elif high_mark is not None:
-            kwargs["limit"] = high_mark - low_mark
-        try:
-            return self.build_query().count(**kwargs)
-        except EmptyResultSet:
-            return 0
 
     def build_query(self, columns=None):
         """Check if the query is supported and prepare a MongoQuery."""
@@ -540,6 +503,7 @@ def insert(self, docs, returning_fields=None):
 
 class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
     def execute_sql(self, result_type=MULTI):
+        self.pre_sql_setup()
         cursor = Cursor()
         cursor.rowcount = self.build_query().delete()
         return cursor
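When there are no grouping keys at all (a bare aggregate() with no selected GROUP BY columns), _build_group_pipeline() takes the $facet branch shown in the second hunk: the single $group result is computed inside a facet and then flattened back onto the top level with $getField over $arrayElemAt. An illustrative shape of that pipeline, with the accumulator value assumed rather than taken from the patch:

    # Hypothetical query: Book.objects.aggregate(total=Count("id"))
    group = {"_id": None, "total": {"$sum": 1}}  # "$sum" accumulator is an assumption
    pipeline = [
        {"$facet": {"group": [{"$group": group}]}},
        {
            "$addFields": {
                key: {
                    "$getField": {
                        "input": {"$arrayElemAt": ["$group", 0]},
                        "field": key,
                    }
                }
                for key in group
            }
        },
    ]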