Skip to content

Commit fee311a

Browse files
committed
Support Union.
1 parent cd1434e commit fee311a

File tree

3 files changed

+140
-33
lines changed

3 files changed

+140
-33
lines changed

django_mongodb/compiler.py

Lines changed: 123 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.db import IntegrityError, NotSupportedError
88
from django.db.models import Count
99
from django.db.models.aggregates import Aggregate, Variance
10-
from django.db.models.expressions import Case, Col, Ref, Value, When
10+
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
1111
from django.db.models.functions.comparison import Coalesce
1212
from django.db.models.functions.math import Power
1313
from django.db.models.lookups import IsNull
@@ -32,6 +32,33 @@ def __init__(self, *args, **kwargs):
3232
# A list of OrderBy objects for this query.
3333
self.order_by_objs = None
3434

35+
def _unfold_column(self, col):
36+
"""
37+
Flatten a field by returning its target or by replacing dots with GROUP_SEPARATOR
38+
for foreign fields.
39+
"""
40+
if self.collection_name == col.alias:
41+
return col.target.column
42+
# If this is a foreign field, replace the normal dot (.) with
43+
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
44+
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}"
45+
46+
def _fold_columns(self, unfold_columns):
47+
"""
48+
Convert flat columns into a nested dictionary, grouping fields by table names.
49+
"""
50+
result = defaultdict(dict)
51+
for key in unfold_columns:
52+
value = f"$_id.{key}"
53+
if self.GROUP_SEPARATOR in key:
54+
table, field = key.split(self.GROUP_SEPARATOR)
55+
result[table][field] = value
56+
else:
57+
result[key] = value
58+
# Convert defaultdict to dict so it doesn't appear as
59+
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
60+
return dict(result)
61+
3562
def _get_group_alias_column(self, expr, annotation_group_idx):
3663
"""Generate a dummy field for use in the ids fields in $group."""
3764
replacement = None
@@ -42,11 +69,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
4269
alias = f"__annotation_group{next(annotation_group_idx)}"
4370
col = self._get_column_from_expression(expr, alias)
4471
replacement = col
45-
if self.collection_name == col.alias:
46-
return col.target.column, replacement
47-
# If this is a foreign field, replace the normal dot (.) with
48-
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
49-
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}", replacement
72+
return self._unfold_column(col), replacement
5073

5174
def _get_column_from_expression(self, expr, alias):
5275
"""
@@ -186,17 +209,8 @@ def _build_aggregation_pipeline(self, ids, group):
186209
else:
187210
group["_id"] = ids
188211
pipeline.append({"$group": group})
189-
projected_fields = defaultdict(dict)
190-
for key in ids:
191-
value = f"$_id.{key}"
192-
if self.GROUP_SEPARATOR in key:
193-
table, field = key.split(self.GROUP_SEPARATOR)
194-
projected_fields[table][field] = value
195-
else:
196-
projected_fields[key] = value
197-
# Convert defaultdict to dict so it doesn't appear as
198-
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
199-
pipeline.append({"$addFields": dict(projected_fields)})
212+
projected_fields = self._fold_columns(ids)
213+
pipeline.append({"$addFields": projected_fields})
200214
if "_id" not in projected_fields:
201215
pipeline.append({"$unset": "_id"})
202216
return pipeline
@@ -349,23 +363,30 @@ def build_query(self, columns=None):
349363
"""Check if the query is supported and prepare a MongoQuery."""
350364
self.check_query()
351365
query = self.query_class(self)
352-
query.lookup_pipeline = self.get_lookup_pipeline()
353366
ordering_fields, sort_ordering, extra_fields = self._get_ordering()
354-
query.project_fields = self.get_project_fields(columns, ordering_fields)
355367
query.ordering = sort_ordering
356-
# If columns is None, then get_project_fields() won't add
357-
# ordering_fields to $project. Use $addFields (extra_fields) instead.
358-
if columns is None:
359-
extra_fields += ordering_fields
368+
if self.query.combinator:
369+
if not getattr(self.connection.features, f"supports_select_{self.query.combinator}"):
370+
raise NotSupportedError(
371+
f"{self.query.combinator} is not supported on this database backend."
372+
)
373+
query.combinator_pipeline = self.get_combinator_queries()
374+
else:
375+
query.project_fields = self.get_project_fields(columns, ordering_fields)
376+
# If columns is None, then get_project_fields() won't add
377+
# ordering_fields to $project. Use $addFields (extra_fields) instead.
378+
if columns is None:
379+
extra_fields += ordering_fields
380+
query.lookup_pipeline = self.get_lookup_pipeline()
381+
where = self.get_where()
382+
try:
383+
expr = where.as_mql(self, self.connection) if where else {}
384+
except FullResultSet:
385+
query.mongo_query = {}
386+
else:
387+
query.mongo_query = {"$expr": expr}
360388
if extra_fields:
361389
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
362-
where = self.get_where()
363-
try:
364-
expr = where.as_mql(self, self.connection) if where else {}
365-
except FullResultSet:
366-
query.mongo_query = {}
367-
else:
368-
query.mongo_query = {"$expr": expr}
369390
return query
370391

371392
def get_columns(self):
@@ -391,6 +412,9 @@ def project_field(column):
391412
if hasattr(column, "target"):
392413
# column is a Col.
393414
target = column.target.column
415+
# Handle Order By columns as refs columns.
416+
elif isinstance(column, OrderBy) and isinstance(column.expression, Ref):
417+
target = column.expression.refs
394418
else:
395419
# column is a Transform in values()/values_list() that needs a
396420
# name for $proj.
@@ -412,6 +436,75 @@ def collection_name(self):
412436
def collection(self):
413437
return self.connection.get_collection(self.collection_name)
414438

439+
def get_combinator_queries(self):
440+
parts = []
441+
compilers = [
442+
query.get_compiler(self.using, self.connection, self.elide_empty)
443+
for query in self.query.combined_queries
444+
]
445+
main_query_columns = self.get_columns()
446+
main_query_fields, _ = zip(*main_query_columns, strict=True)
447+
for compiler_ in compilers:
448+
try:
449+
# If the columns list is limited, then all combined queries
450+
# must have the same columns list. Set the selects defined on
451+
# the query on all combined queries, if not already set.
452+
if not compiler_.query.values_select and self.query.values_select:
453+
compiler_.query = compiler_.query.clone()
454+
compiler_.query.set_values(
455+
(
456+
*self.query.extra_select,
457+
*self.query.values_select,
458+
*self.query.annotation_select,
459+
)
460+
)
461+
compiler_.pre_sql_setup()
462+
columns = compiler_.get_columns()
463+
parts.append((compiler_.build_query(columns), compiler_, columns))
464+
except EmptyResultSet:
465+
# Omit the empty queryset with UNION.
466+
if self.query.combinator == "union":
467+
continue
468+
raise
469+
# Raise EmptyResultSet if all the combinator queries are empty.
470+
if not parts:
471+
raise EmptyResultSet
472+
# Make the combinator's stages.
473+
combinator_pipeline = None
474+
for part, compiler_, columns in parts:
475+
inner_pipeline = part.get_pipeline()
476+
# Standardize result fields.
477+
fields = {}
478+
# When a .count() is called, the main_query_field has length 1
479+
# otherwise it has the same length as columns.
480+
for alias, (ref, expr) in zip(main_query_fields, columns, strict=False):
481+
if isinstance(expr, Col) and expr.alias != compiler_.collection_name:
482+
fields[expr.alias] = 1
483+
else:
484+
fields[alias] = f"${ref}" if alias != ref else 1
485+
inner_pipeline.append({"$project": fields})
486+
# Combine query with the current combinator pipeline.
487+
if combinator_pipeline:
488+
combinator_pipeline.append(
489+
{"$unionWith": {"coll": compiler_.collection_name, "pipeline": inner_pipeline}}
490+
)
491+
else:
492+
combinator_pipeline = inner_pipeline
493+
if not self.query.combinator_all:
494+
ids = {}
495+
for alias, expr in main_query_columns:
496+
# Unfold foreign fields.
497+
if isinstance(expr, Col) and expr.alias != self.collection_name:
498+
ids[self._unfold_column(expr)] = expr.as_mql(self, self.connection)
499+
else:
500+
ids[alias] = f"${alias}"
501+
combinator_pipeline.append({"$group": {"_id": ids}})
502+
projected_fields = self._fold_columns(ids)
503+
combinator_pipeline.append({"$addFields": projected_fields})
504+
if "_id" not in projected_fields:
505+
combinator_pipeline.append({"$unset": "_id"})
506+
return combinator_pipeline
507+
415508
def get_lookup_pipeline(self):
416509
result = []
417510
for alias in tuple(self.query.alias_map):

django_mongodb/features.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2222
supports_paramstyle_pyformat = False
2323
supports_select_difference = False
2424
supports_select_intersection = False
25-
# Not implemented: https://github.com/mongodb-labs/django-mongodb/issues/72
26-
supports_select_union = False
2725
supports_sequence_reset = False
26+
supports_slicing_ordering_in_compound = True
2827
supports_table_check_constraints = False
2928
supports_temporal_subtraction = True
3029
# MongoDB stores datetimes in UTC.
@@ -255,6 +254,7 @@ def django_test_expected_failures(self):
255254
"Test assumes integer primary key.": {
256255
"db_functions.comparison.test_cast.CastTests.test_cast_to_integer_foreign_key",
257256
"model_fields.test_foreignkey.ForeignKeyTests.test_to_python",
257+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_order_raises_on_non_selected_column",
258258
},
259259
"Exists is not supported on MongoDB.": {
260260
"aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_on_exists",
@@ -288,6 +288,7 @@ def django_test_expected_failures(self):
288288
"model_forms.tests.LimitChoicesToTests.test_limit_choices_to_m2m_through",
289289
"model_forms.tests.LimitChoicesToTests.test_limit_choices_to_no_duplicates",
290290
"null_queries.tests.NullQueriesTests.test_reverse_relations",
291+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_with_values_list_on_annotated_and_unannotated",
291292
"queries.tests.ExcludeTest17600.test_exclude_plain",
292293
"queries.tests.ExcludeTest17600.test_exclude_with_q_is_equal_to_plain_exclude_variation",
293294
"queries.tests.ExcludeTest17600.test_exclude_with_q_object_no_distinct",
@@ -352,6 +353,8 @@ def django_test_expected_failures(self):
352353
"lookup.tests.LookupQueryingTests.test_filter_subquery_lhs",
353354
"model_fields.test_jsonfield.TestQuerying.test_nested_key_transform_on_subquery",
354355
"model_fields.test_jsonfield.TestQuerying.test_obj_subquery_lookup",
356+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery",
357+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery_related_outerref",
355358
},
356359
"Using a QuerySet in annotate() is not supported on MongoDB.": {
357360
"aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_reused_subquery",
@@ -389,6 +392,7 @@ def django_test_expected_failures(self):
389392
"model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery",
390393
"one_to_one.tests.OneToOneTests.test_get_prefetch_queryset_warning",
391394
"one_to_one.tests.OneToOneTests.test_rel_pk_subquery",
395+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_with_ordering",
392396
"queries.tests.CloneTests.test_evaluated_queryset_as_argument",
393397
"queries.tests.DoubleInSubqueryTests.test_double_subquery_in",
394398
"queries.tests.EmptyQuerySetTests.test_values_subquery",
@@ -489,6 +493,8 @@ def django_test_expected_failures(self):
489493
"ordering.tests.OrderingTests.test_extra_ordering",
490494
"ordering.tests.OrderingTests.test_extra_ordering_quoting",
491495
"ordering.tests.OrderingTests.test_extra_ordering_with_table_name",
496+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_multiple_models_with_values_list_and_order_by_extra_select",
497+
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_with_extra_and_values_list",
492498
"queries.tests.EscapingTests.test_ticket_7302",
493499
"queries.tests.Queries5Tests.test_extra_select_literal_percent_s",
494500
"queries.tests.Queries5Tests.test_ticket7256",

django_mongodb/query.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.core.exceptions import EmptyResultSet, FullResultSet
55
from django.db import DatabaseError, IntegrityError
6-
from django.db.models.expressions import Case, When
6+
from django.db.models.expressions import Case, OrderBy, When
77
from django.db.models.functions import Mod
88
from django.db.models.lookups import Exact
99
from django.db.models.sql.constants import INNER
@@ -51,6 +51,7 @@ def __init__(self, compiler):
5151
self.project_fields = None
5252
self.aggregation_pipeline = compiler.aggregation_pipeline
5353
self.extra_fields = None
54+
self.combinator_pipeline = None
5455

5556
def __repr__(self):
5657
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
@@ -78,6 +79,8 @@ def get_pipeline(self):
7879
pipeline.extend(self.aggregation_pipeline)
7980
if self.project_fields:
8081
pipeline.append({"$project": self.project_fields})
82+
if self.combinator_pipeline:
83+
pipeline.extend(self.combinator_pipeline)
8184
if self.extra_fields:
8285
pipeline.append({"$addFields": self.extra_fields})
8386
if self.ordering:
@@ -166,6 +169,10 @@ def join(self, compiler, connection):
166169
return lookup_pipeline
167170

168171

172+
def order_by(self, compiler, connection):
173+
return self.expression.as_mql(compiler, connection)
174+
175+
169176
def where_node(self, compiler, connection):
170177
if self.connector == AND:
171178
full_needed, empty_needed = len(self.children), 1
@@ -231,4 +238,5 @@ def where_node(self, compiler, connection):
231238
def register_nodes():
232239
Join.as_mql = join
233240
NothingNode.as_mql = NothingNode.as_sql
241+
OrderBy.as_mql = order_by
234242
WhereNode.as_mql = where_node

0 commit comments

Comments
 (0)