Skip to content

Commit d5c3f7a

Browse files
committed
Wrapping group when no group by.
1 parent 23a892a commit d5c3f7a

File tree

4 files changed

+116
-19
lines changed

4 files changed

+116
-19
lines changed

django_mongodb/compiler.py

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,18 @@ class SQLCompiler(compiler.SQLCompiler):
1818
query_class = MongoQuery
1919
_group_pipeline = None
2020

21+
@staticmethod
22+
def _random_separtor():
23+
import random
24+
import string
25+
26+
size = 6
27+
chars = string.ascii_uppercase + string.digits
28+
return "".join(random.choice(chars) for _ in range(size)) # noqa: S311
29+
2130
def pre_sql_setup(self, with_col_aliases=False):
2231
pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
2332
self.annotations = {}
24-
# mongo_having = self.having.copy() if self.having else None
2533
group = {}
2634
group_expressions = set()
2735
aggregation_idx = 1
@@ -37,13 +45,14 @@ def pre_sql_setup(self, with_col_aliases=False):
3745
aggregation_idx += 1
3846
else:
3947
alias = target
40-
group_expressions |= set(sub_expr.get_group_by_cols())
4148
group[alias] = sub_expr.as_mql(self, self.connection)
4249
column_target = expr.output_field.__class__()
50+
column_target.db_column = alias
4351
column_target.set_attributes_from_name(alias)
4452
replacements[sub_expr] = Col(self.collection_name, column_target)
4553
result_expr = expr.replace_expressions(replacements)
4654
all_replacements.update(replacements)
55+
group_expressions |= set(expr.get_group_by_cols())
4756
self.annotations[target] = result_expr
4857
if group:
4958
order_by = self.get_order_by()
@@ -58,24 +67,69 @@ def pre_sql_setup(self, with_col_aliases=False):
5867
if isinstance(self.query.group_by, tuple | list):
5968
group_expressions |= set(self.query.group_by)
6069

70+
all_strings = "".join(
71+
str(col.as_mql(self, self.connection)) for col in group_expressions
72+
)
73+
74+
while True:
75+
random_string = self._random_separtor()
76+
if random_string not in all_strings:
77+
break
78+
SEPARATOR = f"__{random_string}__"
79+
80+
def _ccc(col):
81+
if self.collection_name == col.alias:
82+
return col.target.column
83+
return f"{col.alias}{SEPARATOR}{col.target.column}"
84+
6185
ids = (
6286
None
6387
if not group_expressions
6488
else {
65-
col.target.column: col.as_mql(self, self.connection)
89+
_ccc(col): col.as_mql(self, self.connection)
6690
# expression aren't needed in the group by clouse ()
6791
for col in group_expressions
6892
if isinstance(col, Col)
6993
}
7094
)
71-
group["_id"] = ids
72-
pipeline = [{"$group": group}]
73-
if ids:
95+
pipeline = []
96+
if ids is None:
97+
group["_id"] = None
98+
pipeline.append({"$facet": {"group": [{"$group": group}]}})
7499
pipeline.append(
75-
{"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
100+
{
101+
"$project": {
102+
key: {
103+
"$getField": {
104+
"input": {"$arrayElemAt": ["$group", 0]},
105+
"field": key,
106+
}
107+
}
108+
for key in group
109+
}
110+
}
76111
)
77-
if "_id" not in ids:
112+
else:
113+
group["_id"] = ids
114+
pipeline.append({"$group": group})
115+
sets = {}
116+
for key in ids:
117+
value = f"$_id.{key}"
118+
if SEPARATOR in key:
119+
subtable, field = key.split(SEPARATOR)
120+
if subtable not in sets:
121+
sets[subtable] = {}
122+
sets[subtable][field] = value
123+
else:
124+
sets[key] = value
125+
126+
pipeline.append(
127+
# {"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
128+
{"$addFields": sets}
129+
)
130+
if "_id" not in sets:
78131
pipeline.append({"$unset": "_id"})
132+
79133
if self.having:
80134
pipeline.append(
81135
{
@@ -224,14 +278,14 @@ def build_query(self, columns=None):
224278
query = self.query_class(self)
225279
query.aggregation_stage = self.get_aggregation_pipeline()
226280
query.lookup_pipeline = self.get_lookup_pipeline()
227-
query.project_fields = self.get_project_fields(columns)
281+
query.order_by(self._get_ordering())
282+
query.project_fields = self.get_project_fields(columns, ordering=query.ordering)
228283
try:
229284
query.mongo_query = (
230285
{"$expr": self.where.as_mql(self, self.connection)} if self.where else None
231286
)
232287
except FullResultSet:
233288
query.mongo_query = {}
234-
query.order_by(self._get_ordering())
235289
return query
236290

237291
def get_columns(self):
@@ -337,7 +391,7 @@ def _get_aggregate_expressions(self, expr):
337391
def get_aggregation_pipeline(self):
338392
return self._group_pipeline
339393

340-
def get_project_fields(self, columns=None):
394+
def get_project_fields(self, columns=None, ordering=None):
341395
fields = {}
342396
for name, expr in columns or []:
343397
try:
@@ -361,6 +415,10 @@ def get_project_fields(self, columns=None):
361415
if self.query.alias_refcount[alias] and self.collection_name != alias:
362416
fields[alias] = 1
363417

418+
for column, _ in ordering or []:
419+
if column not in fields:
420+
fields[column] = 1
421+
364422
return fields
365423

366424

@@ -461,7 +519,16 @@ def build_query(self, columns=None):
461519
elide_empty=self.elide_empty,
462520
)
463521
compiler.pre_sql_setup(with_col_aliases=False)
464-
query.sub_query = compiler.build_query()
522+
columns = (
523+
compiler.get_columns()
524+
if compiler.query.annotations or not compiler.query.default_cols
525+
else None
526+
)
527+
subquery = compiler.build_query(
528+
# Avoid $project (columns=None) if unneeded.
529+
columns
530+
)
531+
query.subquery = subquery
465532
return query
466533

467534
def _make_result(self, result, columns=None):

django_mongodb/features.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,24 @@ def django_test_expected_failures(self):
636636
"Test not applicable for MongoDB's SQLCompiler.": {
637637
"queries.test_iterator.QuerySetIteratorTests",
638638
},
639+
"skip agg": {
640+
"aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_on_exists",
641+
# Custom aggregations:
642+
"aggregation.tests.AggregateAnnotationPruningTests.test_non_aggregate_annotation_pruned",
643+
"aggregation.tests.AggregateTestCase.test_add_implementation",
644+
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database",
645+
# SQL custom values.
646+
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_date_from_database",
647+
# No sql generate
648+
"aggregation.tests.AggregateAnnotationPruningTests.test_unreferenced_aggregate_annotation_pruned",
649+
"aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_pruned",
650+
# PI expression:
651+
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_decimal_from_database",
652+
# check sql query performance
653+
"aggregation.tests.AggregateAnnotationPruningTests.test_referenced_aggregate_annotation_kept",
654+
# Using a QuerySet in annotate() is not supported on MongoDB
655+
"aggregation.tests.AggregateTestCase.test_group_by_subquery_annotation",
656+
}
639657
}
640658

641659
@cached_property

django_mongodb/functions.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from django.db import NotSupportedError
22
from django.db.models.aggregates import Aggregate, Avg, Count, Max, Min, StdDev, Sum, Variance
3-
from django.db.models.expressions import Case, Func, Value, When
3+
from django.db.models.expressions import Case, Func, Star, Value, When
4+
from django.db.models.functions import Now
45
from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
56
from django.db.models.functions.datetime import (
67
Extract,
@@ -122,12 +123,13 @@ def count(self, compiler, connection, **extra_context): # noqa: ARG001
122123
condition = When(self.filter, then=Value(1))
123124
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
124125
node = copy
126+
cond = process_lhs(node, compiler, connection)
125127
else:
126128
node = self
127-
# lhs_mql = process_lhs(self, compiler, connection)
128-
lhs_mql = process_lhs(node, compiler, connection)
129-
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
130-
return {"$sum": {"$cond": {"if": null_cond, "then": 0, "else": 1}}}
129+
lhs_mql = process_lhs(self, compiler, connection)
130+
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
131+
cond = {"$cond": {"if": null_cond, "then": 0, "else": 1}}
132+
return {"$sum": cond}
131133

132134

133135
def extract(self, compiler, connection):
@@ -163,6 +165,10 @@ def log(self, compiler, connection):
163165
return func(clone, compiler, connection)
164166

165167

168+
def now(self, compiler, connection): # noqa: ARG001
169+
return "$$NOW"
170+
171+
166172
def null_if(self, compiler, connection):
167173
"""Return None if expr1==expr2 else expr1."""
168174
expr1, expr2 = (expr.as_mql(compiler, connection) for expr in self.get_source_expressions())
@@ -196,6 +202,10 @@ def round_(self, compiler, connection):
196202
return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]}
197203

198204

205+
def star(self, compiler, connection): # noqa: ARG001
206+
return {"$literal": True}
207+
208+
199209
def str_index(self, compiler, connection):
200210
lhs = process_lhs(self, compiler, connection)
201211
# StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB.
@@ -243,10 +253,12 @@ def register_functions():
243253
Log.as_mql = log
244254
Lower.as_mql = perserve_null("toLower")
245255
LTrim.as_mql = trim("ltrim")
256+
Now.as_mql = now
246257
NullIf.as_mql = null_if
247258
Replace.as_mql = replace
248259
Round.as_mql = round_
249260
RTrim.as_mql = trim("rtrim")
261+
Star.as_mql = star
250262
StrIndex.as_mql = str_index
251263
Substr.as_mql = substr
252264
Trim.as_mql = trim("trim")

django_mongodb/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(self, compiler):
4747
self.collection = self.compiler.get_collection()
4848
self.collection_name = self.compiler.collection_name
4949
self.mongo_query = getattr(compiler.query, "raw_query", {})
50-
self.sub_query = None
50+
self.subquery = None
5151
self.lookup_pipeline = None
5252
self.annotation_stage = None
5353
self.project_fields = None
@@ -94,7 +94,7 @@ def delete(self):
9494
return self.collection.delete_many(self.mongo_query, **options).deleted_count
9595

9696
def get_pipeline(self, count=False, limit=None, skip=None):
97-
pipeline = [] if self.sub_query is None else self.sub_query.get_pipeline()
97+
pipeline = [] if self.subquery is None else self.subquery.get_pipeline()
9898
if self.lookup_pipeline:
9999
pipeline.extend(self.lookup_pipeline)
100100
if self.mongo_query:

0 commit comments

Comments
 (0)