Skip to content

Commit ae1f5d5

Browse files
WaVEVtimgraham
authored andcommitted
handle the empty group by.
1 parent f00d516 commit ae1f5d5

File tree

2 files changed

+83
-45
lines changed

2 files changed

+83
-45
lines changed

django_mongodb/compiler.py

Lines changed: 82 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from django.db import DatabaseError, IntegrityError, NotSupportedError
55
from django.db.models import Count, Expression
66
from django.db.models.aggregates import Aggregate
7-
from django.db.models.expressions import OrderBy, Value
7+
from django.db.models.expressions import Col, OrderBy, Value
88
from django.db.models.sql import compiler
99
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR, SINGLE
1010
from django.utils.functional import cached_property
@@ -18,6 +18,62 @@ class SQLCompiler(compiler.SQLCompiler):
1818

1919
query_class = MongoQuery
2020

21+
def pre_sql_setup(self):
22+
super().pre_sql_setup()
23+
self.annotations = {}
24+
group = {}
25+
group_expressions = set()
26+
aggregation_idx = 1
27+
for target, expr in self.query.annotation_select.items():
28+
if not expr.contains_aggregate:
29+
result_expr = expr
30+
else:
31+
replacements = {}
32+
for sub_expr in self._get_aggregate_expressions(expr):
33+
alias = f"__aggregation{aggregation_idx}"
34+
group[alias] = sub_expr.as_mql(self, self.connection)
35+
aggregation_idx += 1
36+
column_target = expr.output_field.__class__()
37+
column_target.set_attributes_from_name(alias)
38+
replacements[sub_expr] = Col(self.collection_name, column_target)
39+
result_expr = expr.replace_expressions(replacements)
40+
41+
self.annotations[target] = result_expr
42+
if group:
43+
"""
44+
order_by = self.get_order_by()
45+
for expr, (_, _, is_ref) in order_by:
46+
# Skip references to the SELECT clause, as all expressions in
47+
# the SELECT clause are already part of the GROUP BY.
48+
if not is_ref:
49+
group_expressions |= set(expr.get_group_by_cols())
50+
having_group_by = self.having.get_group_by_cols() if self.having else ()
51+
for expr in having_group_by:
52+
group_expressions.add(expr)
53+
"""
54+
55+
ids = (
56+
None
57+
if not group_expressions
58+
else {
59+
col.target.column: col.as_mql(self, self.connection)
60+
for col in group_expressions
61+
}
62+
)
63+
group["_id"] = ids
64+
65+
pipeline = [{"$group": group}]
66+
if ids:
67+
pipeline.append(
68+
{"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
69+
)
70+
if "_id" not in ids:
71+
pipeline.append({"$unSet": "$_id"})
72+
73+
self._group_pipeline = pipeline
74+
else:
75+
self._group_pipeline = None
76+
2177
def execute_sql(
2278
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
2379
):
@@ -85,7 +141,7 @@ def results_iter(
85141
return rows
86142

87143
def has_results(self):
88-
return bool(self.get_count(check_exists=True))
144+
return bool(self.execute_sql(SINGLE))
89145

90146
def _make_result(self, entity, columns):
91147
"""
@@ -172,9 +228,9 @@ def build_query(self, columns=None):
172228
"""Check if the query is supported and prepare a MongoQuery."""
173229
self.check_query()
174230
query = self.query_class(self)
175-
query.project_fields = self.get_project_fields(columns)
176-
query.lookup_pipeline = self.get_lookup_pipeline()
177231
query.aggregation_stage = self.get_aggregation_pipeline()
232+
query.lookup_pipeline = self.get_lookup_pipeline()
233+
query.project_fields = self.get_project_fields(columns)
178234
try:
179235
query.mongo_query = {"$expr": self.query.where.as_mql(self, self.connection)}
180236
except FullResultSet:
@@ -214,7 +270,7 @@ def project_field(column):
214270

215271
return (
216272
tuple(map(project_field, columns))
217-
+ tuple(self.query.annotation_select.items())
273+
+ tuple(self.annotations.items())
218274
+ tuple(map(project_field, related_columns))
219275
)
220276

@@ -279,52 +335,34 @@ def get_lookup_pipeline(self):
279335
result += self.query.alias_map[alias].as_mql(self, self.connection)
280336
return result
281337

282-
def get_aggregation_pipeline(self):
283-
pipeline = None
284-
if any(isinstance(a, Aggregate) for a in self.query.annotations.values()):
285-
result = {}
286-
# self.get_group_by(self.select, [])
287-
for alias, annotation in self.query.annotation_select.items():
288-
value = annotation.as_mql(self, self.connection)
289-
if isinstance(value, list):
290-
value = value[0]
291-
result[alias] = value
292-
293-
expressions = set()
294-
for expr, *_ in self.select:
295-
expressions |= set(expr.get_group_by_cols())
296-
order_by = self.get_order_by()
297-
for expr, (_, _, is_ref) in order_by:
298-
# Skip references to the SELECT clause, as all expressions in
299-
# the SELECT clause are already part of the GROUP BY.
300-
if not is_ref:
301-
expressions |= set(expr.get_group_by_cols())
302-
having_group_by = self.having.get_group_by_cols() if self.having else ()
303-
for expr in having_group_by:
304-
expressions.add(expr)
305-
306-
ids = (
307-
None
308-
if not expressions
309-
else {col.target.column: col.as_mql(self, self.connection) for col in expressions}
310-
)
311-
result["_id"] = ids
312-
313-
pipeline = [{"$group": result}]
314-
if ids:
315-
pipeline.append(
316-
{"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
338+
def _get_aggregate_expressions2(self, expr):
339+
stack = [(None, expr)]
340+
while stack:
341+
parent, expr = stack.pop()
342+
if isinstance(expr, Aggregate):
343+
yield parent
344+
elif hasattr(expr, "get_source_expressions"):
345+
stack.extend(
346+
[((expr, idx), se) for idx, se in enumerate(expr.get_source_expressions())]
317347
)
318-
if "_id" not in ids:
319-
pipeline.append({"$unSet": "$_id"})
320348

321-
return pipeline
349+
def _get_aggregate_expressions(self, expr):
350+
stack = [expr]
351+
while stack:
352+
expr = stack.pop()
353+
if isinstance(expr, Aggregate):
354+
yield expr
355+
elif hasattr(expr, "get_source_expressions"):
356+
stack.extend(expr.get_source_expressions())
357+
358+
def get_aggregation_pipeline(self):
359+
return self._group_pipeline
322360

323361
def get_project_fields(self, columns=None):
324362
fields = {}
325363
for name, expr in columns or []:
326364
try:
327-
column = name if isinstance(expr, Aggregate) else expr.target.column
365+
column = expr.target.column
328366
except AttributeError:
329367
# Generate the MQL for an annotation.
330368
try:

django_mongodb/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
Count: "sum",
4040
Max: "max",
4141
Min: "min",
42-
StdDev: "stddev",
42+
StdDev: "stdDevPop",
4343
Sum: "sum",
4444
Variance: "stdDevPop",
4545
}

0 commit comments

Comments
 (0)