Skip to content

Commit df2cddd

Browse files
WaVEVtimgraham
authored andcommitted
add support for QuerySet.union()
1 parent 93ac126 commit df2cddd

File tree

3 files changed

+135
-32
lines changed

3 files changed

+135
-32
lines changed

django_mongodb/compiler.py

Lines changed: 124 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.db import IntegrityError, NotSupportedError
88
from django.db.models import Count
99
from django.db.models.aggregates import Aggregate, Variance
10-
from django.db.models.expressions import Case, Col, Ref, Value, When
10+
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1111
from django.db.models.functions.comparison import Coalesce
1212
from django.db.models.functions.math import Power
1313
from django.db.models.lookups import IsNull
@@ -32,6 +32,34 @@ def __init__(self, *args, **kwargs):
3232
# A list of OrderBy objects for this query.
3333
self.order_by_objs = None
3434

35+
def _unfold_column(self, col):
36+
"""
37+
Flatten a field by returning its target or by replacing dots with
38+
GROUP_SEPARATOR for foreign fields.
39+
"""
40+
if self.collection_name == col.alias:
41+
return col.target.column
42+
# If this is a foreign field, replace the normal dot (.) with
43+
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
44+
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}"
45+
46+
def _fold_columns(self, unfold_columns):
47+
"""
48+
Convert flat columns into a nested dictionary, grouping fields by
49+
table name.
50+
"""
51+
result = defaultdict(dict)
52+
for key in unfold_columns:
53+
value = f"$_id.{key}"
54+
if self.GROUP_SEPARATOR in key:
55+
table, field = key.split(self.GROUP_SEPARATOR)
56+
result[table][field] = value
57+
else:
58+
result[key] = value
59+
# Convert defaultdict to dict so it doesn't appear as
60+
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
61+
return dict(result)
62+
3563
def _get_group_alias_column(self, expr, annotation_group_idx):
3664
"""Generate a dummy field for use in the ids fields in $group."""
3765
replacement = None
@@ -42,11 +70,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
4270
alias = f"__annotation_group{next(annotation_group_idx)}"
4371
col = self._get_column_from_expression(expr, alias)
4472
replacement = col
45-
if self.collection_name == col.alias:
46-
return col.target.column, replacement
47-
# If this is a foreign field, replace the normal dot (.) with
48-
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49-
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}", replacement
73+
return self._unfold_column(col), replacement
5074

5175
def _get_column_from_expression(self, expr, alias):
5276
"""
@@ -186,17 +210,8 @@ def _build_aggregation_pipeline(self, ids, group):
186210
else:
187211
group["_id"] = ids
188212
pipeline.append({"$group": group})
189-
projected_fields = defaultdict(dict)
190-
for key in ids:
191-
value = f"$_id.{key}"
192-
if self.GROUP_SEPARATOR in key:
193-
table, field = key.split(self.GROUP_SEPARATOR)
194-
projected_fields[table][field] = value
195-
else:
196-
projected_fields[key] = value
197-
# Convert defaultdict to dict so it doesn't appear as
198-
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
199-
pipeline.append({"$addFields": dict(projected_fields)})
213+
projected_fields = self._fold_columns(ids)
214+
pipeline.append({"$addFields": projected_fields})
200215
if "_id" not in projected_fields:
201216
pipeline.append({"$unset": "_id"})
202217
return pipeline
@@ -349,23 +364,30 @@ def build_query(self, columns=None):
349364
"""Check if the query is supported and prepare a MongoQuery."""
350365
self.check_query()
351366
query = self.query_class(self)
352-
query.lookup_pipeline = self.get_lookup_pipeline()
353367
ordering_fields, sort_ordering, extra_fields = self._get_ordering()
354-
query.project_fields = self.get_project_fields(columns, ordering_fields)
355368
query.ordering = sort_ordering
356-
# If columns is None, then get_project_fields() won't add
357-
# ordering_fields to $project. Use $addFields (extra_fields) instead.
358-
if columns is None:
359-
extra_fields += ordering_fields
369+
if self.query.combinator:
370+
if not getattr(self.connection.features, f"supports_select_{self.query.combinator}"):
371+
raise NotSupportedError(
372+
f"{self.query.combinator} is not supported on this database backend."
373+
)
374+
query.combinator_pipeline = self.get_combinator_queries()
375+
else:
376+
query.project_fields = self.get_project_fields(columns, ordering_fields)
377+
# If columns is None, then get_project_fields() won't add
378+
# ordering_fields to $project. Use $addFields (extra_fields) instead.
379+
if columns is None:
380+
extra_fields += ordering_fields
381+
query.lookup_pipeline = self.get_lookup_pipeline()
382+
where = self.get_where()
383+
try:
384+
expr = where.as_mql(self, self.connection) if where else {}
385+
except FullResultSet:
386+
query.mongo_query = {}
387+
else:
388+
query.mongo_query = {"$expr": expr}
360389
if extra_fields:
361390
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
362-
where = self.get_where()
363-
try:
364-
expr = where.as_mql(self, self.connection) if where else {}
365-
except FullResultSet:
366-
query.mongo_query = {}
367-
else:
368-
query.mongo_query = {"$expr": expr}
369391
return query
370392

371393
def get_columns(self):
@@ -391,6 +413,9 @@ def project_field(column):
391413
if hasattr(column, "target"):
392414
# column is a Col.
393415
target = column.target.column
416+
# Handle Order By columns as refs columns.
417+
elif isinstance(column, OrderBy) and isinstance(column.expression, Ref):
418+
target = column.expression.refs
394419
else:
395420
# column is a Transform in values()/values_list() that needs a
396421
# name for $proj.
@@ -412,6 +437,75 @@ def collection_name(self):
412437
def collection(self):
413438
return self.connection.get_collection(self.collection_name)
414439

440+
def get_combinator_queries(self):
441+
parts = []
442+
compilers = [
443+
query.get_compiler(self.using, self.connection, self.elide_empty)
444+
for query in self.query.combined_queries
445+
]
446+
main_query_columns = self.get_columns()
447+
main_query_fields, _ = zip(*main_query_columns, strict=True)
448+
for compiler_ in compilers:
449+
try:
450+
# If the columns list is limited, then all combined queries
451+
# must have the same columns list. Set the selects defined on
452+
# the query on all combined queries, if not already set.
453+
if not compiler_.query.values_select and self.query.values_select:
454+
compiler_.query = compiler_.query.clone()
455+
compiler_.query.set_values(
456+
(
457+
*self.query.extra_select,
458+
*self.query.values_select,
459+
*self.query.annotation_select,
460+
)
461+
)
462+
compiler_.pre_sql_setup()
463+
columns = compiler_.get_columns()
464+
parts.append((compiler_.build_query(columns), compiler_, columns))
465+
except EmptyResultSet:
466+
# Omit the empty queryset with UNION.
467+
if self.query.combinator == "union":
468+
continue
469+
raise
470+
# Raise EmptyResultSet if all the combinator queries are empty.
471+
if not parts:
472+
raise EmptyResultSet
473+
# Make the combinator's stages.
474+
combinator_pipeline = None
475+
for part, compiler_, columns in parts:
476+
inner_pipeline = part.get_pipeline()
477+
# Standardize result fields.
478+
fields = {}
479+
# When a .count() is called, the main_query_field has length 1
480+
# otherwise it has the same length as columns.
481+
for alias, (ref, expr) in zip(main_query_fields, columns, strict=False):
482+
if isinstance(expr, Col) and expr.alias != compiler_.collection_name:
483+
fields[expr.alias] = 1
484+
else:
485+
fields[alias] = f"${ref}" if alias != ref else 1
486+
inner_pipeline.append({"$project": fields})
487+
# Combine query with the current combinator pipeline.
488+
if combinator_pipeline:
489+
combinator_pipeline.append(
490+
{"$unionWith": {"coll": compiler_.collection_name, "pipeline": inner_pipeline}}
491+
)
492+
else:
493+
combinator_pipeline = inner_pipeline
494+
if not self.query.combinator_all:
495+
ids = {}
496+
for alias, expr in main_query_columns:
497+
# Unfold foreign fields.
498+
if isinstance(expr, Col) and expr.alias != self.collection_name:
499+
ids[self._unfold_column(expr)] = expr.as_mql(self, self.connection)
500+
else:
501+
ids[alias] = f"${alias}"
502+
combinator_pipeline.append({"$group": {"_id": ids}})
503+
projected_fields = self._fold_columns(ids)
504+
combinator_pipeline.append({"$addFields": projected_fields})
505+
if "_id" not in projected_fields:
506+
combinator_pipeline.append({"$unset": "_id"})
507+
return combinator_pipeline
508+
415509
def get_lookup_pipeline(self):
416510
result = []
417511
for alias in tuple(self.query.alias_map):

django_mongodb/features.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2222
supports_paramstyle_pyformat = False
2323
supports_select_difference = False
2424
supports_select_intersection = False
25-
# Not implemented: https://github.com/mongodb-labs/django-mongodb/issues/72
26-
supports_select_union = False
2725
supports_sequence_reset = False
26+
supports_slicing_ordering_in_compound = True
2827
supports_table_check_constraints = False
2928
supports_temporal_subtraction = True
3029
# MongoDB stores datetimes in UTC.
@@ -234,6 +233,7 @@ def django_test_expected_failures(self):
234233
"Test assumes integer primary key.": {
235234
"db_functions.comparison.test_cast.CastTests.test_cast_to_integer_foreign_key",
236235
"model_fields.test_foreignkey.ForeignKeyTests.test_to_python",
236+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_order_raises_on_non_selected_column",
237237
},
238238
"Exists is not supported on MongoDB.": {
239239
"aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_on_exists",
@@ -267,6 +267,7 @@ def django_test_expected_failures(self):
267267
"model_forms.tests.LimitChoicesToTests.test_limit_choices_to_m2m_through",
268268
"model_forms.tests.LimitChoicesToTests.test_limit_choices_to_no_duplicates",
269269
"null_queries.tests.NullQueriesTests.test_reverse_relations",
270+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_with_values_list_on_annotated_and_unannotated",
270271
"queries.tests.ExcludeTest17600.test_exclude_plain",
271272
"queries.tests.ExcludeTest17600.test_exclude_with_q_is_equal_to_plain_exclude_variation",
272273
"queries.tests.ExcludeTest17600.test_exclude_with_q_object_no_distinct",
@@ -331,6 +332,8 @@ def django_test_expected_failures(self):
331332
"lookup.tests.LookupQueryingTests.test_filter_subquery_lhs",
332333
"model_fields.test_jsonfield.TestQuerying.test_nested_key_transform_on_subquery",
333334
"model_fields.test_jsonfield.TestQuerying.test_obj_subquery_lookup",
335+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery",
336+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery_related_outerref",
334337
},
335338
"Using a QuerySet in annotate() is not supported on MongoDB.": {
336339
"aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_reused_subquery",
@@ -368,6 +371,7 @@ def django_test_expected_failures(self):
368371
"model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery",
369372
"one_to_one.tests.OneToOneTests.test_get_prefetch_queryset_warning",
370373
"one_to_one.tests.OneToOneTests.test_rel_pk_subquery",
374+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_with_ordering",
371375
"queries.tests.CloneTests.test_evaluated_queryset_as_argument",
372376
"queries.tests.DoubleInSubqueryTests.test_double_subquery_in",
373377
"queries.tests.EmptyQuerySetTests.test_values_subquery",
@@ -468,6 +472,8 @@ def django_test_expected_failures(self):
468472
"ordering.tests.OrderingTests.test_extra_ordering",
469473
"ordering.tests.OrderingTests.test_extra_ordering_quoting",
470474
"ordering.tests.OrderingTests.test_extra_ordering_with_table_name",
475+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_multiple_models_with_values_list_and_order_by_extra_select",
476+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_with_extra_and_values_list",
471477
"queries.tests.EscapingTests.test_ticket_7302",
472478
"queries.tests.Queries5Tests.test_extra_select_literal_percent_s",
473479
"queries.tests.Queries5Tests.test_ticket7256",

django_mongodb/query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self, compiler):
5151
self.project_fields = None
5252
self.aggregation_pipeline = compiler.aggregation_pipeline
5353
self.extra_fields = None
54+
self.combinator_pipeline = None
5455

5556
def __repr__(self):
5657
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
@@ -78,6 +79,8 @@ def get_pipeline(self):
7879
pipeline.extend(self.aggregation_pipeline)
7980
if self.project_fields:
8081
pipeline.append({"$project": self.project_fields})
82+
if self.combinator_pipeline:
83+
pipeline.extend(self.combinator_pipeline)
8184
if self.extra_fields:
8285
pipeline.append({"$addFields": self.extra_fields})
8386
if self.ordering:

0 commit comments

Comments
 (0)