Skip to content

Commit 8a5be6e

Browse files
WaVEVtimgraham
authored andcommitted
Wrapping group when no group by.
1 parent a4410a1 commit 8a5be6e

File tree

4 files changed

+116
-19
lines changed

4 files changed

+116
-19
lines changed

django_mongodb/compiler.py

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,18 @@ class SQLCompiler(compiler.SQLCompiler):
1919
query_class = MongoQuery
2020
_group_pipeline = None
2121

22+
@staticmethod
23+
def _random_separtor():
24+
import random
25+
import string
26+
27+
size = 6
28+
chars = string.ascii_uppercase + string.digits
29+
return "".join(random.choice(chars) for _ in range(size)) # noqa: S311
30+
2231
def pre_sql_setup(self, with_col_aliases=False):
2332
pre_setup = super().pre_sql_setup(with_col_aliases=with_col_aliases)
2433
self.annotations = {}
25-
# mongo_having = self.having.copy() if self.having else None
2634
group = {}
2735
group_expressions = set()
2836
aggregation_idx = 1
@@ -38,13 +46,14 @@ def pre_sql_setup(self, with_col_aliases=False):
3846
aggregation_idx += 1
3947
else:
4048
alias = target
41-
group_expressions |= set(sub_expr.get_group_by_cols())
4249
group[alias] = sub_expr.as_mql(self, self.connection)
4350
column_target = expr.output_field.__class__()
51+
column_target.db_column = alias
4452
column_target.set_attributes_from_name(alias)
4553
replacements[sub_expr] = Col(self.collection_name, column_target)
4654
result_expr = expr.replace_expressions(replacements)
4755
all_replacements.update(replacements)
56+
group_expressions |= set(expr.get_group_by_cols())
4857
self.annotations[target] = result_expr
4958
if group:
5059
order_by = self.get_order_by()
@@ -59,24 +68,69 @@ def pre_sql_setup(self, with_col_aliases=False):
5968
if isinstance(self.query.group_by, tuple | list):
6069
group_expressions |= set(self.query.group_by)
6170

71+
all_strings = "".join(
72+
str(col.as_mql(self, self.connection)) for col in group_expressions
73+
)
74+
75+
while True:
76+
random_string = self._random_separtor()
77+
if random_string not in all_strings:
78+
break
79+
SEPARATOR = f"__{random_string}__"
80+
81+
def _ccc(col):
82+
if self.collection_name == col.alias:
83+
return col.target.column
84+
return f"{col.alias}{SEPARATOR}{col.target.column}"
85+
6286
ids = (
6387
None
6488
if not group_expressions
6589
else {
66-
col.target.column: col.as_mql(self, self.connection)
90+
_ccc(col): col.as_mql(self, self.connection)
6791
# expression aren't needed in the group by clouse ()
6892
for col in group_expressions
6993
if isinstance(col, Col)
7094
}
7195
)
72-
group["_id"] = ids
73-
pipeline = [{"$group": group}]
74-
if ids:
96+
pipeline = []
97+
if ids is None:
98+
group["_id"] = None
99+
pipeline.append({"$facet": {"group": [{"$group": group}]}})
75100
pipeline.append(
76-
{"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
101+
{
102+
"$project": {
103+
key: {
104+
"$getField": {
105+
"input": {"$arrayElemAt": ["$group", 0]},
106+
"field": key,
107+
}
108+
}
109+
for key in group
110+
}
111+
}
77112
)
78-
if "_id" not in ids:
113+
else:
114+
group["_id"] = ids
115+
pipeline.append({"$group": group})
116+
sets = {}
117+
for key in ids:
118+
value = f"$_id.{key}"
119+
if SEPARATOR in key:
120+
subtable, field = key.split(SEPARATOR)
121+
if subtable not in sets:
122+
sets[subtable] = {}
123+
sets[subtable][field] = value
124+
else:
125+
sets[key] = value
126+
127+
pipeline.append(
128+
# {"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
129+
{"$addFields": sets}
130+
)
131+
if "_id" not in sets:
79132
pipeline.append({"$unset": "_id"})
133+
80134
if self.having:
81135
pipeline.append(
82136
{
@@ -250,14 +304,14 @@ def build_query(self, columns=None):
250304
query = self.query_class(self)
251305
query.aggregation_stage = self.get_aggregation_pipeline()
252306
query.lookup_pipeline = self.get_lookup_pipeline()
253-
query.project_fields = self.get_project_fields(columns)
307+
query.order_by(self._get_ordering())
308+
query.project_fields = self.get_project_fields(columns, ordering=query.ordering)
254309
try:
255310
query.mongo_query = (
256311
{"$expr": self.where.as_mql(self, self.connection)} if self.where else None
257312
)
258313
except FullResultSet:
259314
query.mongo_query = {}
260-
query.order_by(self._get_ordering())
261315
return query
262316

263317
def get_columns(self):
@@ -369,7 +423,7 @@ def _get_aggregate_expressions(self, expr):
369423
def get_aggregation_pipeline(self):
370424
return self._group_pipeline
371425

372-
def get_project_fields(self, columns=None):
426+
def get_project_fields(self, columns=None, ordering=None):
373427
fields = {}
374428
for name, expr in columns or []:
375429
try:
@@ -393,6 +447,10 @@ def get_project_fields(self, columns=None):
393447
if self.query.alias_refcount[alias] and self.collection_name != alias:
394448
fields[alias] = 1
395449

450+
for column, _ in ordering or []:
451+
if column not in fields:
452+
fields[column] = 1
453+
396454
return fields
397455

398456

@@ -513,7 +571,16 @@ def build_query(self, columns=None):
513571
elide_empty=self.elide_empty,
514572
)
515573
compiler.pre_sql_setup(with_col_aliases=False)
516-
query.sub_query = compiler.build_query()
574+
columns = (
575+
compiler.get_columns()
576+
if compiler.query.annotations or not compiler.query.default_cols
577+
else None
578+
)
579+
subquery = compiler.build_query(
580+
# Avoid $project (columns=None) if unneeded.
581+
columns
582+
)
583+
query.subquery = subquery
517584
return query
518585

519586
def _make_result(self, result, columns=None):

django_mongodb/features.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,24 @@ def django_test_expected_failures(self):
594594
"Test not applicable for MongoDB's SQLCompiler.": {
595595
"queries.test_iterator.QuerySetIteratorTests",
596596
},
597+
"skip agg": {
598+
"aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_on_exists",
599+
# Custom aggregations:
600+
"aggregation.tests.AggregateAnnotationPruningTests.test_non_aggregate_annotation_pruned",
601+
"aggregation.tests.AggregateTestCase.test_add_implementation",
602+
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database",
603+
# SQL custom values.
604+
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_date_from_database",
605+
# No sql generate
606+
"aggregation.tests.AggregateAnnotationPruningTests.test_unreferenced_aggregate_annotation_pruned",
607+
"aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_pruned",
608+
# PI expression:
609+
"aggregation.tests.AggregateTestCase.test_aggregation_default_using_decimal_from_database",
610+
# check sql query performance
611+
"aggregation.tests.AggregateAnnotationPruningTests.test_referenced_aggregate_annotation_kept",
612+
# Using a QuerySet in annotate() is not supported on MongoDB
613+
"aggregation.tests.AggregateTestCase.test_group_by_subquery_annotation",
614+
}
597615
}
598616

599617
@cached_property

django_mongodb/functions.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from django.db import NotSupportedError
22
from django.db.models.aggregates import Aggregate, Avg, Count, Max, Min, StdDev, Sum, Variance
3-
from django.db.models.expressions import Case, Func, Value, When
3+
from django.db.models.expressions import Case, Func, Star, Value, When
4+
from django.db.models.functions import Now
45
from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
56
from django.db.models.functions.datetime import (
67
Extract,
@@ -122,12 +123,13 @@ def count(self, compiler, connection, **extra_context): # noqa: ARG001
122123
condition = When(self.filter, then=Value(1))
123124
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
124125
node = copy
126+
cond = process_lhs(node, compiler, connection)
125127
else:
126128
node = self
127-
# lhs_mql = process_lhs(self, compiler, connection)
128-
lhs_mql = process_lhs(node, compiler, connection)
129-
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
130-
return {"$sum": {"$cond": {"if": null_cond, "then": 0, "else": 1}}}
129+
lhs_mql = process_lhs(self, compiler, connection)
130+
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
131+
cond = {"$cond": {"if": null_cond, "then": 0, "else": 1}}
132+
return {"$sum": cond}
131133

132134

133135
def extract(self, compiler, connection):
@@ -163,6 +165,10 @@ def log(self, compiler, connection):
163165
return func(clone, compiler, connection)
164166

165167

168+
def now(self, compiler, connection): # noqa: ARG001
169+
return "$$NOW"
170+
171+
166172
def null_if(self, compiler, connection):
167173
"""Return None if expr1==expr2 else expr1."""
168174
expr1, expr2 = (expr.as_mql(compiler, connection) for expr in self.get_source_expressions())
@@ -196,6 +202,10 @@ def round_(self, compiler, connection):
196202
return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]}
197203

198204

205+
def star(self, compiler, connection): # noqa: ARG001
206+
return {"$literal": True}
207+
208+
199209
def str_index(self, compiler, connection):
200210
lhs = process_lhs(self, compiler, connection)
201211
# StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB.
@@ -243,10 +253,12 @@ def register_functions():
243253
Log.as_mql = log
244254
Lower.as_mql = perserve_null("toLower")
245255
LTrim.as_mql = trim("ltrim")
256+
Now.as_mql = now
246257
NullIf.as_mql = null_if
247258
Replace.as_mql = replace
248259
Round.as_mql = round_
249260
RTrim.as_mql = trim("rtrim")
261+
Star.as_mql = star
250262
StrIndex.as_mql = str_index
251263
Substr.as_mql = substr
252264
Trim.as_mql = trim("trim")

django_mongodb/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(self, compiler):
4747
self.collection = self.compiler.get_collection()
4848
self.collection_name = self.compiler.collection_name
4949
self.mongo_query = getattr(compiler.query, "raw_query", {})
50-
self.sub_query = None
50+
self.subquery = None
5151
self.lookup_pipeline = None
5252
self.annotation_stage = None
5353
self.project_fields = None
@@ -90,7 +90,7 @@ def delete(self):
9090
return self.collection.delete_many(self.mongo_query, **options).deleted_count
9191

9292
def get_pipeline(self, count=False, limit=None, skip=None):
93-
pipeline = [] if self.sub_query is None else self.sub_query.get_pipeline()
93+
pipeline = [] if self.subquery is None else self.subquery.get_pipeline()
9494
if self.lookup_pipeline:
9595
pipeline.extend(self.lookup_pipeline)
9696
if self.mongo_query:

0 commit comments

Comments
 (0)