Skip to content

Commit ee83b2a

Browse files
committed
Improvements.
1 parent 7cb8cb5 commit ee83b2a

File tree

1 file changed

+30
-16
lines changed

1 file changed

+30
-16
lines changed

django_mongodb/compiler.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.db import IntegrityError, NotSupportedError
88
from django.db.models import Count
99
from django.db.models.aggregates import Aggregate, Variance
10-
from django.db.models.expressions import Case, Col, Ref, Value, When
10+
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1111
from django.db.models.functions.comparison import Coalesce
1212
from django.db.models.functions.math import Power
1313
from django.db.models.lookups import IsNull
@@ -381,8 +381,10 @@ def get_columns(self):
381381
which should be loaded by the query.
382382
"""
383383
select_mask = self.query.get_select_mask()
384-
columns = (
385-
self.get_default_columns(select_mask) if self.query.default_cols else self.query.select
384+
columns = filter(
385+
# The extra order by columns are handled by order_by_objs variables.
386+
lambda col: not isinstance(col, OrderBy),
387+
self.get_default_columns(select_mask) if self.query.default_cols else self.query.select,
386388
)
387389
# Populate QuerySet.select_related() data.
388390
related_columns = []
@@ -439,13 +441,10 @@ def get_combinator_queries(self):
439441
*self.query.annotation_select,
440442
)
441443
)
442-
compiler_.pre_sql_setup(with_col_aliases=False)
443-
# Avoid $project (columns=None) if unneeded.
444-
columns = (
445-
compiler_.get_columns()
446-
if compiler_.query.annotations or not compiler_.query.default_cols
447-
else None
448-
)
444+
compiler_.pre_sql_setup()
445+
# Standardize columns as main query required.
446+
_, exprs = zip(*compiler_.get_columns(), strict=True)
447+
columns = tuple(zip(self.query.values_select, exprs, strict=True))
449448
parts.append((compiler_.build_query(columns), compiler_.collection_name))
450449

451450
except EmptyResultSet:
@@ -454,7 +453,9 @@ def get_combinator_queries(self):
454453
if self.query.combinator == "union":
455454
continue
456455
raise
457-
456+
# Raise EmptyResultSet if all the combinator queries are empty.
457+
if not parts:
458+
raise EmptyResultSet
458459
combinator_pipeline = parts.pop(0)[0].get_pipeline() if parts else None
459460
if self.query.combinator == "union":
460461
for part, collection in parts:
@@ -463,13 +464,26 @@ def get_combinator_queries(self):
463464
)
464465
if not self.query.combinator_all:
465466
ids = {}
466-
annotation_group_idx = itertools.count(start=1)
467-
for _, expr in self.get_columns():
468-
alias, replacement = self._get_group_alias_column(
469-
expr, annotation_group_idx
467+
for alias, expr in self.get_columns():
468+
ids[alias] = (
469+
expr.as_mql(self, self.connection)
470+
if isinstance(expr, Col | Ref)
471+
else f"${alias}"
470472
)
471-
ids[alias] = expr.as_mql(self, self.connection)
472473
combinator_pipeline.append({"$group": {"_id": ids}})
474+
projected_fields = defaultdict(dict)
475+
for key in ids:
476+
value = f"$_id.{key}"
477+
if self.GROUP_SEPARATOR in key:
478+
table, field = key.split(self.GROUP_SEPARATOR)
479+
projected_fields[table][field] = value
480+
else:
481+
projected_fields[key] = value
482+
# Convert defaultdict to dict so it doesn't appear as
483+
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
484+
combinator_pipeline.append({"$addFields": dict(projected_fields)})
485+
if "_id" not in projected_fields:
486+
combinator_pipeline.append({"$unset": "_id"})
473487
else:
474488
raise NotSupportedError(f"Combinator {self.query.combinator} isn't supported.")
475489
return combinator_pipeline

0 commit comments

Comments
 (0)