Skip to content

Commit 12fd2cd

Browse files
committed
Support union.
1 parent ee83b2a commit 12fd2cd

File tree

2 files changed

+34
-22
lines changed

2 files changed

+34
-22
lines changed

django_mongodb/compiler.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -350,21 +350,19 @@ def build_query(self, columns=None):
350350
self.check_query()
351351
query = self.query_class(self)
352352
ordering_fields, sort_ordering, extra_fields = self._get_ordering()
353-
query.project_fields = self.get_project_fields(columns, ordering_fields)
354353
query.ordering = sort_ordering
355-
# If columns is None, then get_project_fields() won't add
356-
# ordering_fields to $project. Use $addFields (extra_fields) instead.
357-
if columns is None:
358-
extra_fields += ordering_fields
359-
if extra_fields:
360-
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
361354
if self.query.combinator:
362355
if not getattr(self.connection.features, f"supports_select_{self.query.combinator}"):
363356
raise NotSupportedError(
364357
f"{self.query.combinator} is not supported on this database backend."
365358
)
366359
query.combinator_pipeline = self.get_combinator_queries()
367360
else:
361+
query.project_fields = self.get_project_fields(columns, ordering_fields)
362+
# If columns is None, then get_project_fields() won't add
363+
# ordering_fields to $project. Use $addFields (extra_fields) instead.
364+
if columns is None:
365+
extra_fields += ordering_fields
368366
query.lookup_pipeline = self.get_lookup_pipeline()
369367
where = self.get_where()
370368
try:
@@ -373,6 +371,8 @@ def build_query(self, columns=None):
373371
query.mongo_query = {}
374372
else:
375373
query.mongo_query = {"$expr": expr}
374+
if extra_fields:
375+
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
376376
return query
377377

378378
def get_columns(self):
@@ -381,10 +381,8 @@ def get_columns(self):
381381
which should be loaded by the query.
382382
"""
383383
select_mask = self.query.get_select_mask()
384-
columns = filter(
385-
# The extra order by columns are handled by order_by_objs variables.
386-
lambda col: not isinstance(col, OrderBy),
387-
self.get_default_columns(select_mask) if self.query.default_cols else self.query.select,
384+
columns = (
385+
self.get_default_columns(select_mask) if self.query.default_cols else self.query.select
388386
)
389387
# Populate QuerySet.select_related() data.
390388
related_columns = []
@@ -400,6 +398,9 @@ def project_field(column):
400398
if hasattr(column, "target"):
401399
# column is a Col.
402400
target = column.target.column
401+
# Handle Order by columns as refs columns.
402+
elif isinstance(column, OrderBy) and isinstance(column.expression, Ref):
403+
target = column.expression.refs
403404
else:
404405
# column is a Transform in values()/values_list() that needs a
405406
# name for $proj.
@@ -427,6 +428,8 @@ def get_combinator_queries(self):
427428
query.get_compiler(self.using, self.connection, self.elide_empty)
428429
for query in self.query.combined_queries
429430
]
431+
main_query_columns = self.get_columns()
432+
main_query_fields, _ = zip(*main_query_columns, strict=True)
430433
for compiler_ in compilers:
431434
try:
432435
# If the columns list is limited, then all combined queries
@@ -443,8 +446,10 @@ def get_combinator_queries(self):
443446
)
444447
compiler_.pre_sql_setup()
445448
# Standardize columns as main query required.
446-
_, exprs = zip(*compiler_.get_columns(), strict=True)
447-
columns = tuple(zip(self.query.values_select, exprs, strict=True))
449+
columns = compiler_.get_columns()
450+
if self.query.values_select:
451+
_, exprs = zip(*columns, strict=True)
452+
columns = tuple(zip(main_query_fields, exprs, strict=True))
448453
parts.append((compiler_.build_query(columns), compiler_.collection_name))
449454

450455
except EmptyResultSet:
@@ -464,12 +469,14 @@ def get_combinator_queries(self):
464469
)
465470
if not self.query.combinator_all:
466471
ids = {}
467-
for alias, expr in self.get_columns():
468-
ids[alias] = (
469-
expr.as_mql(self, self.connection)
470-
if isinstance(expr, Col | Ref)
471-
else f"${alias}"
472-
)
472+
for alias, expr in main_query_columns:
473+
collection = expr.alias if isinstance(expr, Col) else None
474+
if collection and collection != self.collection_name:
475+
ids[
476+
f"{expr.alias}{self.GROUP_SEPARATOR}{expr.target.column}"
477+
] = expr.as_mql(self, self.connection)
478+
else:
479+
ids[alias] = f"${alias}"
473480
combinator_pipeline.append({"$group": {"_id": ids}})
474481
projected_fields = defaultdict(dict)
475482
for key in ids:

django_mongodb/query.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.core.exceptions import EmptyResultSet, FullResultSet
55
from django.db import DatabaseError, IntegrityError
6-
from django.db.models.expressions import Case, When
6+
from django.db.models.expressions import Case, OrderBy, When
77
from django.db.models.functions import Mod
88
from django.db.models.lookups import Exact
99
from django.db.models.sql.constants import INNER
@@ -80,6 +80,8 @@ def get_pipeline(self):
8080
pipeline.extend(self.aggregation_pipeline)
8181
if self.project_fields:
8282
pipeline.append({"$project": self.project_fields})
83+
if self.combinator_pipeline:
84+
pipeline.extend(self.combinator_pipeline)
8385
if self.extra_fields:
8486
pipeline.append({"$addFields": self.extra_fields})
8587
if self.ordering:
@@ -88,8 +90,6 @@ def get_pipeline(self):
8890
pipeline.append({"$skip": self.query.low_mark})
8991
if self.query.high_mark is not None:
9092
pipeline.append({"$limit": self.query.high_mark - self.query.low_mark})
91-
if self.combinator_pipeline:
92-
pipeline.extend(self.combinator_pipeline)
9393

9494
return subquery_pipeline + pipeline
9595

@@ -171,6 +171,10 @@ def join(self, compiler, connection):
171171
return lookup_pipeline
172172

173173

174+
def orderby(self, compiler, connection):
175+
return self.expression.as_mql(compiler, connection)
176+
177+
174178
def where_node(self, compiler, connection):
175179
if self.connector == AND:
176180
full_needed, empty_needed = len(self.children), 1
@@ -236,4 +240,5 @@ def where_node(self, compiler, connection):
236240
def register_nodes():
237241
Join.as_mql = join
238242
NothingNode.as_mql = NothingNode.as_sql
243+
OrderBy.as_mql = orderby
239244
WhereNode.as_mql = where_node

0 commit comments

Comments
 (0)