Skip to content

Commit 76e2c31

Browse files
committed
Support Union.
1 parent 89aa87f commit 76e2c31

File tree

3 files changed

+107
-16
lines changed

3 files changed

+107
-16
lines changed

django_mongodb/compiler.py

Lines changed: 97 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.db import IntegrityError, NotSupportedError
88
from django.db.models import Count
99
from django.db.models.aggregates import Aggregate, Variance
10-
from django.db.models.expressions import Case, Col, Ref, Value, When
10+
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1111
from django.db.models.functions.comparison import Coalesce
1212
from django.db.models.functions.math import Power
1313
from django.db.models.lookups import IsNull
@@ -349,23 +349,30 @@ def build_query(self, columns=None):
349349
"""Check if the query is supported and prepare a MongoQuery."""
350350
self.check_query()
351351
query = self.query_class(self)
352-
query.lookup_pipeline = self.get_lookup_pipeline()
353352
ordering_fields, sort_ordering, extra_fields = self._get_ordering()
354-
query.project_fields = self.get_project_fields(columns, ordering_fields)
355353
query.ordering = sort_ordering
356-
# If columns is None, then get_project_fields() won't add
357-
# ordering_fields to $project. Use $addFields (extra_fields) instead.
358-
if columns is None:
359-
extra_fields += ordering_fields
354+
if self.query.combinator:
355+
if not getattr(self.connection.features, f"supports_select_{self.query.combinator}"):
356+
raise NotSupportedError(
357+
f"{self.query.combinator} is not supported on this database backend."
358+
)
359+
query.combinator_pipeline = self.get_combinator_queries()
360+
else:
361+
query.project_fields = self.get_project_fields(columns, ordering_fields)
362+
# If columns is None, then get_project_fields() won't add
363+
# ordering_fields to $project. Use $addFields (extra_fields) instead.
364+
if columns is None:
365+
extra_fields += ordering_fields
366+
query.lookup_pipeline = self.get_lookup_pipeline()
367+
where = self.get_where()
368+
try:
369+
expr = where.as_mql(self, self.connection) if where else {}
370+
except FullResultSet:
371+
query.mongo_query = {}
372+
else:
373+
query.mongo_query = {"$expr": expr}
360374
if extra_fields:
361375
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
362-
where = self.get_where()
363-
try:
364-
expr = where.as_mql(self, self.connection) if where else {}
365-
except FullResultSet:
366-
query.mongo_query = {}
367-
else:
368-
query.mongo_query = {"$expr": expr}
369376
return query
370377

371378
def get_columns(self):
@@ -391,6 +398,9 @@ def project_field(column):
391398
if hasattr(column, "target"):
392399
# column is a Col.
393400
target = column.target.column
401+
# Handle Order by columns as refs columns.
402+
elif isinstance(column, OrderBy) and isinstance(column.expression, Ref):
403+
target = column.expression.refs
394404
else:
395405
# column is a Transform in values()/values_list() that needs a
396406
# name for $proj.
@@ -412,6 +422,79 @@ def collection_name(self):
412422
def collection(self):
413423
return self.connection.get_collection(self.collection_name)
414424

425+
def get_combinator_queries(self):
426+
parts = []
427+
compilers = [
428+
query.get_compiler(self.using, self.connection, self.elide_empty)
429+
for query in self.query.combined_queries
430+
]
431+
main_query_columns = self.get_columns()
432+
main_query_fields, _ = zip(*main_query_columns, strict=True)
433+
for compiler_ in compilers:
434+
try:
435+
# If the columns list is limited, then all combined queries
436+
# must have the same columns list. Set the selects defined on
437+
# the query on all combined queries, if not already set.
438+
if not compiler_.query.values_select and self.query.values_select:
439+
compiler_.query = compiler_.query.clone()
440+
compiler_.query.set_values(
441+
(
442+
*self.query.extra_select,
443+
*self.query.values_select,
444+
*self.query.annotation_select,
445+
)
446+
)
447+
compiler_.pre_sql_setup()
448+
# Standardize columns as main query required.
449+
columns = compiler_.get_columns()
450+
if self.query.values_select:
451+
_, exprs = zip(*columns, strict=True)
452+
columns = tuple(zip(main_query_fields, exprs, strict=True))
453+
parts.append((compiler_.build_query(columns), compiler_.collection_name))
454+
455+
except EmptyResultSet:
456+
# Omit the empty queryset with UNION and with DIFFERENCE if the
457+
# first queryset is nonempty.
458+
if self.query.combinator == "union":
459+
continue
460+
raise
461+
# Raise EmptyResultSet if all the combinator queries are empty.
462+
if not parts:
463+
raise EmptyResultSet
464+
combinator_pipeline = parts.pop(0)[0].get_pipeline()
465+
if self.query.combinator == "union":
466+
for part, collection in parts:
467+
combinator_pipeline.append(
468+
{"$unionWith": {"coll": collection, "pipeline": part.get_pipeline()}}
469+
)
470+
if not self.query.combinator_all:
471+
ids = {}
472+
for alias, expr in main_query_columns:
473+
collection = expr.alias if isinstance(expr, Col) else None
474+
if collection and collection != self.collection_name:
475+
ids[
476+
f"{expr.alias}{self.GROUP_SEPARATOR}{expr.target.column}"
477+
] = expr.as_mql(self, self.connection)
478+
else:
479+
ids[alias] = f"${alias}"
480+
combinator_pipeline.append({"$group": {"_id": ids}})
481+
projected_fields = defaultdict(dict)
482+
for key in ids:
483+
value = f"$_id.{key}"
484+
if self.GROUP_SEPARATOR in key:
485+
table, field = key.split(self.GROUP_SEPARATOR)
486+
projected_fields[table][field] = value
487+
else:
488+
projected_fields[key] = value
489+
# Convert defaultdict to dict so it doesn't appear as
490+
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
491+
combinator_pipeline.append({"$addFields": dict(projected_fields)})
492+
if "_id" not in projected_fields:
493+
combinator_pipeline.append({"$unset": "_id"})
494+
else:
495+
raise NotSupportedError(f"Combinator {self.query.combinator} isn't supported.")
496+
return combinator_pipeline
497+
415498
def get_lookup_pipeline(self):
416499
result = []
417500
for alias in tuple(self.query.alias_map):

django_mongodb/features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2323
supports_select_difference = False
2424
supports_select_intersection = False
2525
# Not implemented: https://github.com/mongodb-labs/django-mongodb/issues/72
26-
supports_select_union = False
26+
supports_select_union = True
2727
supports_sequence_reset = False
2828
supports_table_check_constraints = False
2929
supports_temporal_subtraction = True

django_mongodb/query.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.core.exceptions import EmptyResultSet, FullResultSet
55
from django.db import DatabaseError, IntegrityError
6-
from django.db.models.expressions import Case, When
6+
from django.db.models.expressions import Case, OrderBy, When
77
from django.db.models.functions import Mod
88
from django.db.models.lookups import Exact
99
from django.db.models.sql.constants import INNER
@@ -51,6 +51,7 @@ def __init__(self, compiler):
5151
self.project_fields = None
5252
self.aggregation_pipeline = compiler.aggregation_pipeline
5353
self.extra_fields = None
54+
self.combinator_pipeline = None
5455

5556
def __repr__(self):
5657
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
@@ -78,6 +79,8 @@ def get_pipeline(self):
7879
pipeline.extend(self.aggregation_pipeline)
7980
if self.project_fields:
8081
pipeline.append({"$project": self.project_fields})
82+
if self.combinator_pipeline:
83+
pipeline.extend(self.combinator_pipeline)
8184
if self.extra_fields:
8285
pipeline.append({"$addFields": self.extra_fields})
8386
if self.ordering:
@@ -166,6 +169,10 @@ def join(self, compiler, connection):
166169
return lookup_pipeline
167170

168171

172+
def orderby(self, compiler, connection):
173+
return self.expression.as_mql(compiler, connection)
174+
175+
169176
def where_node(self, compiler, connection):
170177
if self.connector == AND:
171178
full_needed, empty_needed = len(self.children), 1
@@ -231,4 +238,5 @@ def where_node(self, compiler, connection):
231238
def register_nodes():
232239
Join.as_mql = join
233240
NothingNode.as_mql = NothingNode.as_sql
241+
OrderBy.as_mql = orderby
234242
WhereNode.as_mql = where_node

0 commit comments

Comments
 (0)