Skip to content

Commit 93ee4b8

Browse files
committed
Edits.
1 parent cbe87a7 commit 93ee4b8

File tree

1 file changed

+36
-13
lines changed

1 file changed

+36
-13
lines changed

django_mongodb/compiler.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
from collections import defaultdict
23

34
from django.core.exceptions import EmptyResultSet, FullResultSet
45
from django.db import DatabaseError, IntegrityError, NotSupportedError
@@ -19,7 +20,7 @@ class SQLCompiler(compiler.SQLCompiler):
1920
"""Base class for all Mongo compilers."""
2021

2122
query_class = MongoQuery
22-
SEPARATOR = "10__MESSI__3"
23+
GROUP_SEPARATOR = "___"
2324

2425
def __init__(self, *args, **kwargs):
2526
super().__init__(*args, **kwargs)
@@ -35,7 +36,7 @@ def _get_group_alias_column(self, col, annotation_group_idx):
3536
col = col_expr
3637
if self.collection_name == col.alias:
3738
return col.target.column, replacement
38-
return f"{col.alias}{self.SEPARATOR}{col.target.column}", replacement
39+
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}", replacement
3940

4041
def _get_column_from_expression(self, expr, alias):
4142
"""Get column target from expression."""
@@ -45,7 +46,21 @@ def _get_column_from_expression(self, expr, alias):
4546
return Col(self.collection_name, column_target)
4647

4748
def _prepare_expressions_for_pipeline(self, expression, target, count):
48-
"""Prepare expressions for the aggregation pipeline."""
49+
"""
50+
Prepare expressions for the aggregation pipeline.
51+
52+
This function handles the computation of aggregation functions used by various expressions.
53+
It separates and creates intermediate columns, and replaces nodes
54+
to simulate a group by operation.
55+
56+
In MongoDB, the `$group` stage does not allow operations over the aggregator
57+
(e.g., `COALESCE(AVG(field), 3)`). However, it does support operations inside
58+
the aggregation (e.g., `AVG(number * 2)`).
59+
This function manages first cases by splitting the computation into stages:
60+
it computes the aggregation function first, then applies additional operations
61+
in a subsequent stage by replacing the aggregate expressions
62+
with new columns prefixed by `__aggregation`.
63+
"""
4964
replacements = {}
5065
group = {}
5166
for sub_expr in self._get_aggregate_expressions(expression):
@@ -150,19 +165,17 @@ def _build_group_pipeline(self, ids, group):
150165
else:
151166
group["_id"] = ids
152167
pipeline.append({"$group": group})
153-
sets = {}
168+
projected_fields = defaultdict(dict)
154169
for key in ids:
155170
value = f"$_id.{key}"
156-
if self.SEPARATOR in key:
157-
subtable, field = key.split(self.SEPARATOR)
158-
if subtable not in sets:
159-
sets[subtable] = {}
160-
sets[subtable][field] = value
171+
if self.GROUP_SEPARATOR in key:
172+
subtable, field = key.split(self.GROUP_SEPARATOR)
173+
projected_fields[subtable][field] = value
161174
else:
162-
sets[key] = value
175+
projected_fields[key] = value
163176

164-
pipeline.append({"$addFields": sets})
165-
if "_id" not in sets:
177+
pipeline.append({"$addFields": projected_fields})
178+
if "_id" not in projected_fields:
166179
pipeline.append({"$unset": "_id"})
167180

168181
return pipeline
@@ -320,7 +333,8 @@ def build_query(self, columns=None):
320333
query.order_by(self._get_ordering())
321334
query.project_fields = self.get_project_fields(columns, ordering=query.ordering)
322335
try:
323-
where = getattr(self, "where", self.query.where)
336+
where = self.get_where()
337+
# where = getattr(self, "where", self.query.where)
324338
query.mongo_query = (
325339
{"$expr": where.as_mql(self, self.connection)} if where is not None else {}
326340
)
@@ -466,6 +480,9 @@ def get_project_fields(self, columns=None, ordering=None):
466480
fields[column] = 1
467481
return fields
468482

483+
def get_where(self):
484+
return self.where
485+
469486

470487
class SQLInsertCompiler(SQLCompiler):
471488
def execute_sql(self, returning_fields=None):
@@ -511,6 +528,9 @@ def check_query(self):
511528
"Cannot use QuerySet.delete() when querying across multiple collections on MongoDB."
512529
)
513530

531+
def get_where(self):
532+
return self.query.where
533+
514534

515535
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
516536
def execute_sql(self, result_type):
@@ -572,6 +592,9 @@ def check_query(self):
572592
"Cannot use QuerySet.update() when querying across multiple collections on MongoDB."
573593
)
574594

595+
def get_where(self):
596+
return self.query.where
597+
575598

576599
class SQLAggregateCompiler(SQLCompiler):
577600
def build_query(self, columns=None):

0 commit comments

Comments
 (0)