Skip to content

Commit 7cb8cb5

Browse files
committed
Support Union query, first approach.
1 parent 89aa87f commit 7cb8cb5

File tree

3 files changed

+77
-10
lines changed

3 files changed

+77
-10
lines changed

django_mongodb/compiler.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,6 @@ def build_query(self, columns=None):
349349
"""Check if the query is supported and prepare a MongoQuery."""
350350
self.check_query()
351351
query = self.query_class(self)
352-
query.lookup_pipeline = self.get_lookup_pipeline()
353352
ordering_fields, sort_ordering, extra_fields = self._get_ordering()
354353
query.project_fields = self.get_project_fields(columns, ordering_fields)
355354
query.ordering = sort_ordering
@@ -359,13 +358,21 @@ def build_query(self, columns=None):
359358
extra_fields += ordering_fields
360359
if extra_fields:
361360
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
362-
where = self.get_where()
363-
try:
364-
expr = where.as_mql(self, self.connection) if where else {}
365-
except FullResultSet:
366-
query.mongo_query = {}
361+
if self.query.combinator:
362+
if not getattr(self.connection.features, f"supports_select_{self.query.combinator}"):
363+
raise NotSupportedError(
364+
f"{self.query.combinator} is not supported on this database backend."
365+
)
366+
query.combinator_pipeline = self.get_combinator_queries()
367367
else:
368-
query.mongo_query = {"$expr": expr}
368+
query.lookup_pipeline = self.get_lookup_pipeline()
369+
where = self.get_where()
370+
try:
371+
expr = where.as_mql(self, self.connection) if where else {}
372+
except FullResultSet:
373+
query.mongo_query = {}
374+
else:
375+
query.mongo_query = {"$expr": expr}
369376
return query
370377

371378
def get_columns(self):
@@ -412,6 +419,61 @@ def collection_name(self):
412419
def collection(self):
413420
return self.connection.get_collection(self.collection_name)
414421

422+
def get_combinator_queries(self):
423+
parts = []
424+
compilers = [
425+
query.get_compiler(self.using, self.connection, self.elide_empty)
426+
for query in self.query.combined_queries
427+
]
428+
for compiler_ in compilers:
429+
try:
430+
# If the columns list is limited, then all combined queries
431+
# must have the same columns list. Set the selects defined on
432+
# the query on all combined queries, if not already set.
433+
if not compiler_.query.values_select and self.query.values_select:
434+
compiler_.query = compiler_.query.clone()
435+
compiler_.query.set_values(
436+
(
437+
*self.query.extra_select,
438+
*self.query.values_select,
439+
*self.query.annotation_select,
440+
)
441+
)
442+
compiler_.pre_sql_setup(with_col_aliases=False)
443+
# Avoid $project (columns=None) if unneeded.
444+
columns = (
445+
compiler_.get_columns()
446+
if compiler_.query.annotations or not compiler_.query.default_cols
447+
else None
448+
)
449+
parts.append((compiler_.build_query(columns), compiler_.collection_name))
450+
451+
except EmptyResultSet:
452+
# Omit the empty queryset with UNION and with DIFFERENCE if the
453+
# first queryset is nonempty.
454+
if self.query.combinator == "union":
455+
continue
456+
raise
457+
458+
combinator_pipeline = parts.pop(0)[0].get_pipeline() if parts else None
459+
if self.query.combinator == "union":
460+
for part, collection in parts:
461+
combinator_pipeline.append(
462+
{"$unionWith": {"coll": collection, "pipeline": part.get_pipeline()}}
463+
)
464+
if not self.query.combinator_all:
465+
ids = {}
466+
annotation_group_idx = itertools.count(start=1)
467+
for _, expr in self.get_columns():
468+
alias, replacement = self._get_group_alias_column(
469+
expr, annotation_group_idx
470+
)
471+
ids[alias] = expr.as_mql(self, self.connection)
472+
combinator_pipeline.append({"$group": {"_id": ids}})
473+
else:
474+
raise NotSupportedError(f"Combinator {self.query.combinator} isn't supported.")
475+
return combinator_pipeline
476+
415477
def get_lookup_pipeline(self):
416478
result = []
417479
for alias in tuple(self.query.alias_map):

django_mongodb/features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2323
supports_select_difference = False
2424
supports_select_intersection = False
2525
# Not implemented: https://github.com/mongodb-labs/django-mongodb/issues/72
26-
supports_select_union = False
26+
supports_select_union = True
2727
supports_sequence_reset = False
2828
supports_table_check_constraints = False
2929
supports_temporal_subtraction = True

django_mongodb/query.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self, compiler):
5151
self.project_fields = None
5252
self.aggregation_pipeline = compiler.aggregation_pipeline
5353
self.extra_fields = None
54+
self.combinator_pipeline = None
5455

5556
def __repr__(self):
5657
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
@@ -69,7 +70,8 @@ def get_cursor(self):
6970
return self.collection.aggregate(self.get_pipeline())
7071

7172
def get_pipeline(self):
72-
pipeline = self.subquery.get_pipeline() if self.subquery else []
73+
subquery_pipeline = self.subquery.get_pipeline() if self.subquery else []
74+
pipeline = []
7375
if self.lookup_pipeline:
7476
pipeline.extend(self.lookup_pipeline)
7577
if self.mongo_query:
@@ -86,7 +88,10 @@ def get_pipeline(self):
8688
pipeline.append({"$skip": self.query.low_mark})
8789
if self.query.high_mark is not None:
8890
pipeline.append({"$limit": self.query.high_mark - self.query.low_mark})
89-
return pipeline
91+
if self.combinator_pipeline:
92+
pipeline.extend(self.combinator_pipeline)
93+
94+
return subquery_pipeline + pipeline
9095

9196

9297
def join(self, compiler, connection):

0 commit comments

Comments
 (0)