Skip to content

Commit acc99da

Browse files
committed
handle the empty group by.
1 parent 931a0b4 commit acc99da

File tree

2 files changed

+92
-51
lines changed

2 files changed

+92
-51
lines changed

django_mongodb/compiler.py

Lines changed: 91 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from django.db import DatabaseError, IntegrityError, NotSupportedError
33
from django.db.models import Count, Expression
44
from django.db.models.aggregates import Aggregate
5-
from django.db.models.expressions import OrderBy, Value
5+
from django.db.models.expressions import Col, OrderBy, Value
66
from django.db.models.sql import compiler
7-
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR
7+
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR, SINGLE
8+
89
from django.utils.functional import cached_property
910

1011
from .base import Cursor
@@ -16,6 +17,62 @@ class SQLCompiler(compiler.SQLCompiler):
1617

1718
query_class = MongoQuery
1819

20+
def pre_sql_setup(self):
21+
super().pre_sql_setup()
22+
self.annotations = {}
23+
group = {}
24+
group_expressions = set()
25+
aggregation_idx = 1
26+
for target, expr in self.query.annotation_select.items():
27+
if not expr.contains_aggregate:
28+
result_expr = expr
29+
else:
30+
replacements = {}
31+
for sub_expr in self._get_aggregate_expressions(expr):
32+
alias = f"__aggregation{aggregation_idx}"
33+
group[alias] = sub_expr.as_mql(self, self.connection)
34+
aggregation_idx += 1
35+
column_target = expr.output_field.__class__()
36+
column_target.set_attributes_from_name(alias)
37+
replacements[sub_expr] = Col(self.collection_name, column_target)
38+
result_expr = expr.replace_expressions(replacements)
39+
40+
self.annotations[target] = result_expr
41+
if group:
42+
"""
43+
order_by = self.get_order_by()
44+
for expr, (_, _, is_ref) in order_by:
45+
# Skip references to the SELECT clause, as all expressions in
46+
# the SELECT clause are already part of the GROUP BY.
47+
if not is_ref:
48+
group_expressions |= set(expr.get_group_by_cols())
49+
having_group_by = self.having.get_group_by_cols() if self.having else ()
50+
for expr in having_group_by:
51+
group_expressions.add(expr)
52+
"""
53+
54+
ids = (
55+
None
56+
if not group_expressions
57+
else {
58+
col.target.column: col.as_mql(self, self.connection)
59+
for col in group_expressions
60+
}
61+
)
62+
group["_id"] = ids
63+
64+
pipeline = [{"$group": group}]
65+
if ids:
66+
pipeline.append(
67+
{"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
68+
)
69+
if "_id" not in ids:
70+
pipeline.append({"$unSet": "$_id"})
71+
72+
self._group_pipeline = pipeline
73+
else:
74+
self._group_pipeline = None
75+
1976
def execute_sql(
2077
self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
2178
):
@@ -33,11 +90,13 @@ def execute_sql(
3390
except EmptyResultSet:
3491
return iter([]) if result_type == MULTI else None
3592

36-
return (
37-
(self._make_result(row, columns) for row in query.fetch())
38-
if result_type == MULTI
39-
else self._make_result(next(query.fetch()), columns)
40-
)
93+
if result_type == MULTI:
94+
return (self._make_result(row, columns) for row in query.fetch())
95+
96+
try:
97+
return self._make_result(next(query.fetch()), columns)
98+
except StopIteration:
99+
return None
41100

42101
def results_iter(
43102
self,
@@ -64,7 +123,7 @@ def results_iter(
64123
return rows
65124

66125
def has_results(self):
67-
return bool(self.get_count(check_exists=True))
126+
return bool(self.execute_sql(SINGLE))
68127

69128
def _make_result(self, entity, columns):
70129
"""
@@ -143,9 +202,9 @@ def build_query(self, columns=None):
143202
"""Check if the query is supported and prepare a MongoQuery."""
144203
self.check_query()
145204
query = self.query_class(self)
146-
query.project_fields = self.get_project_fields(columns)
147-
query.lookup_pipeline = self.get_lookup_pipeline()
148205
query.aggregation_stage = self.get_aggregation_pipeline()
206+
query.lookup_pipeline = self.get_lookup_pipeline()
207+
query.project_fields = self.get_project_fields(columns)
149208
try:
150209
query.mongo_query = {"$expr": self.query.where.as_mql(self, self.connection)}
151210
except FullResultSet:
@@ -185,7 +244,7 @@ def project_field(column):
185244

186245
return (
187246
tuple(map(project_field, columns))
188-
+ tuple(self.query.annotation_select.items())
247+
+ tuple(self.annotations.items())
189248
+ tuple(map(project_field, related_columns))
190249
)
191250

@@ -250,52 +309,34 @@ def get_lookup_pipeline(self):
250309
result += self.query.alias_map[alias].as_mql(self, self.connection)
251310
return result
252311

253-
def get_aggregation_pipeline(self):
254-
pipeline = None
255-
if any(isinstance(a, Aggregate) for a in self.query.annotations.values()):
256-
result = {}
257-
# self.get_group_by(self.select, [])
258-
for alias, annotation in self.query.annotation_select.items():
259-
value = annotation.as_mql(self, self.connection)
260-
if isinstance(value, list):
261-
value = value[0]
262-
result[alias] = value
263-
264-
expressions = set()
265-
for expr, *_ in self.select:
266-
expressions |= set(expr.get_group_by_cols())
267-
order_by = self.get_order_by()
268-
for expr, (_, _, is_ref) in order_by:
269-
# Skip references to the SELECT clause, as all expressions in
270-
# the SELECT clause are already part of the GROUP BY.
271-
if not is_ref:
272-
expressions |= set(expr.get_group_by_cols())
273-
having_group_by = self.having.get_group_by_cols() if self.having else ()
274-
for expr in having_group_by:
275-
expressions.add(expr)
276-
277-
ids = (
278-
None
279-
if not expressions
280-
else {col.target.column: col.as_mql(self, self.connection) for col in expressions}
281-
)
282-
result["_id"] = ids
283-
284-
pipeline = [{"$group": result}]
285-
if ids:
286-
pipeline.append(
287-
{"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
312+
def _get_aggregate_expressions2(self, expr):
313+
stack = [(None, expr)]
314+
while stack:
315+
parent, expr = stack.pop()
316+
if isinstance(expr, Aggregate):
317+
yield parent
318+
elif hasattr(expr, "get_source_expressions"):
319+
stack.extend(
320+
[((expr, idx), se) for idx, se in enumerate(expr.get_source_expressions())]
288321
)
289-
if "_id" not in ids:
290-
pipeline.append({"$unSet": "$_id"})
291322

292-
return pipeline
323+
def _get_aggregate_expressions(self, expr):
324+
stack = [expr]
325+
while stack:
326+
expr = stack.pop()
327+
if isinstance(expr, Aggregate):
328+
yield expr
329+
elif hasattr(expr, "get_source_expressions"):
330+
stack.extend(expr.get_source_expressions())
331+
332+
def get_aggregation_pipeline(self):
333+
return self._group_pipeline
293334

294335
def get_project_fields(self, columns=None):
295336
fields = {}
296337
for name, expr in columns or []:
297338
try:
298-
column = name if isinstance(expr, Aggregate) else expr.target.column
339+
column = expr.target.column
299340
except AttributeError:
300341
# Generate the MQL for an annotation.
301342
try:

django_mongodb/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
Count: "sum",
4040
Max: "max",
4141
Min: "min",
42-
StdDev: "stddev",
42+
StdDev: "stdDevPop",
4343
Sum: "sum",
4444
Variance: "stdDevPop",
4545
}

0 commit comments

Comments
 (0)