1
1
import itertools
2
+ from collections import defaultdict
2
3
3
4
from django .core .exceptions import EmptyResultSet , FullResultSet
4
5
from django .db import DatabaseError , IntegrityError , NotSupportedError
@@ -19,7 +20,7 @@ class SQLCompiler(compiler.SQLCompiler):
19
20
"""Base class for all Mongo compilers."""
20
21
21
22
query_class = MongoQuery
22
- SEPARATOR = "10__MESSI__3 "
23
+ GROUP_SEPARATOR = "___ "
23
24
24
25
def __init__ (self , * args , ** kwargs ):
25
26
super ().__init__ (* args , ** kwargs )
@@ -35,7 +36,7 @@ def _get_group_alias_column(self, col, annotation_group_idx):
35
36
col = col_expr
36
37
if self .collection_name == col .alias :
37
38
return col .target .column , replacement
38
- return f"{ col .alias } { self .SEPARATOR } { col .target .column } " , replacement
39
+ return f"{ col .alias } { self .GROUP_SEPARATOR } { col .target .column } " , replacement
39
40
40
41
def _get_column_from_expression (self , expr , alias ):
41
42
"""Get column target from expression."""
@@ -45,7 +46,21 @@ def _get_column_from_expression(self, expr, alias):
45
46
return Col (self .collection_name , column_target )
46
47
47
48
def _prepare_expressions_for_pipeline (self , expression , target , count ):
48
- """Prepare expressions for the aggregation pipeline."""
49
+ """
50
+ Prepare expressions for the aggregation pipeline.
51
+
52
+ This function handles the computation of aggregation functions used by various expressions.
53
+ It separates and creates intermediate columns, and replaces nodes
54
+ to simulate a group by operation.
55
+
56
+ In MongoDB, the `$group` stage does not allow operations over the aggregator
57
+ (e.g., `COALESCE(AVG(field), 3)`). However, it does support operations inside
58
+ the aggregation (e.g., `AVG(number * 2)`).
59
+ This function manages first cases by splitting the computation into stages:
60
+ it computes the aggregation function first, then applies additional operations
61
+ in a subsequent stage by replacing the aggregate expressions
62
+ with new columns prefixed by `__aggregation`.
63
+ """
49
64
replacements = {}
50
65
group = {}
51
66
for sub_expr in self ._get_aggregate_expressions (expression ):
@@ -150,19 +165,17 @@ def _build_group_pipeline(self, ids, group):
150
165
else :
151
166
group ["_id" ] = ids
152
167
pipeline .append ({"$group" : group })
153
- sets = {}
168
+ projected_fields = defaultdict ( dict )
154
169
for key in ids :
155
170
value = f"$_id.{ key } "
156
- if self .SEPARATOR in key :
157
- subtable , field = key .split (self .SEPARATOR )
158
- if subtable not in sets :
159
- sets [subtable ] = {}
160
- sets [subtable ][field ] = value
171
+ if self .GROUP_SEPARATOR in key :
172
+ subtable , field = key .split (self .GROUP_SEPARATOR )
173
+ projected_fields [subtable ][field ] = value
161
174
else :
162
- sets [key ] = value
175
+ projected_fields [key ] = value
163
176
164
- pipeline .append ({"$addFields" : sets })
165
- if "_id" not in sets :
177
+ pipeline .append ({"$addFields" : projected_fields })
178
+ if "_id" not in projected_fields :
166
179
pipeline .append ({"$unset" : "_id" })
167
180
168
181
return pipeline
@@ -320,7 +333,8 @@ def build_query(self, columns=None):
320
333
query .order_by (self ._get_ordering ())
321
334
query .project_fields = self .get_project_fields (columns , ordering = query .ordering )
322
335
try :
323
- where = getattr (self , "where" , self .query .where )
336
+ where = self .get_where ()
337
+ # where = getattr(self, "where", self.query.where)
324
338
query .mongo_query = (
325
339
{"$expr" : where .as_mql (self , self .connection )} if where is not None else {}
326
340
)
@@ -466,6 +480,9 @@ def get_project_fields(self, columns=None, ordering=None):
466
480
fields [column ] = 1
467
481
return fields
468
482
483
+ def get_where (self ):
484
+ return self .where
485
+
469
486
470
487
class SQLInsertCompiler (SQLCompiler ):
471
488
def execute_sql (self , returning_fields = None ):
@@ -511,6 +528,9 @@ def check_query(self):
511
528
"Cannot use QuerySet.delete() when querying across multiple collections on MongoDB."
512
529
)
513
530
531
+ def get_where (self ):
532
+ return self .query .where
533
+
514
534
515
535
class SQLUpdateCompiler (compiler .SQLUpdateCompiler , SQLCompiler ):
516
536
def execute_sql (self , result_type ):
@@ -572,6 +592,9 @@ def check_query(self):
572
592
"Cannot use QuerySet.update() when querying across multiple collections on MongoDB."
573
593
)
574
594
595
+ def get_where (self ):
596
+ return self .query .where
597
+
575
598
576
599
class SQLAggregateCompiler (SQLCompiler ):
577
600
def build_query (self , columns = None ):
0 commit comments