diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index d975e4446..282de7c79 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -7,7 +7,7 @@ from django.db import IntegrityError, NotSupportedError from django.db.models import Count from django.db.models.aggregates import Aggregate, Variance -from django.db.models.expressions import Case, Col, Ref, Value, When +from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When from django.db.models.functions.comparison import Coalesce from django.db.models.functions.math import Power from django.db.models.lookups import IsNull @@ -32,6 +32,34 @@ def __init__(self, *args, **kwargs): # A list of OrderBy objects for this query. self.order_by_objs = None + def _unfold_column(self, col): + """ + Flatten a field by returning its target or by replacing dots with + GROUP_SEPARATOR for foreign fields. + """ + if self.collection_name == col.alias: + return col.target.column + # If this is a foreign field, replace the normal dot (.) with + # GROUP_SEPARATOR since FieldPath field names may not contain '.'. + return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}" + + def _fold_columns(self, unfold_columns): + """ + Convert flat columns into a nested dictionary, grouping fields by + table name. + """ + result = defaultdict(dict) + for key in unfold_columns: + value = f"$_id.{key}" + if self.GROUP_SEPARATOR in key: + table, field = key.split(self.GROUP_SEPARATOR) + result[table][field] = value + else: + result[key] = value + # Convert defaultdict to dict so it doesn't appear as + # "defaultdict(, ..." in query logging. + return dict(result) + def _get_group_alias_column(self, expr, annotation_group_idx): """Generate a dummy field for use in the ids fields in $group.""" replacement = None @@ -42,11 +70,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx): alias = f"__annotation_group{next(annotation_group_idx)}" col = self._get_column_from_expression(expr, alias) replacement = col - if self.collection_name == col.alias: - return col.target.column, replacement - # If this is a foreign field, replace the normal dot (.) with - # GROUP_SEPARATOR since FieldPath field names may not contain '.'. - return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}", replacement + return self._unfold_column(col), replacement def _get_column_from_expression(self, expr, alias): """ @@ -186,17 +210,8 @@ def _build_aggregation_pipeline(self, ids, group): else: group["_id"] = ids pipeline.append({"$group": group}) - projected_fields = defaultdict(dict) - for key in ids: - value = f"$_id.{key}" - if self.GROUP_SEPARATOR in key: - table, field = key.split(self.GROUP_SEPARATOR) - projected_fields[table][field] = value - else: - projected_fields[key] = value - # Convert defaultdict to dict so it doesn't appear as - # "defaultdict(, ..." in query logging. - pipeline.append({"$addFields": dict(projected_fields)}) + projected_fields = self._fold_columns(ids) + pipeline.append({"$addFields": projected_fields}) if "_id" not in projected_fields: pipeline.append({"$unset": "_id"}) return pipeline @@ -349,23 +364,30 @@ def build_query(self, columns=None): """Check if the query is supported and prepare a MongoQuery.""" self.check_query() query = self.query_class(self) - query.lookup_pipeline = self.get_lookup_pipeline() ordering_fields, sort_ordering, extra_fields = self._get_ordering() - query.project_fields = self.get_project_fields(columns, ordering_fields) query.ordering = sort_ordering - # If columns is None, then get_project_fields() won't add - # ordering_fields to $project. Use $addFields (extra_fields) instead. - if columns is None: - extra_fields += ordering_fields + if self.query.combinator: + if not getattr(self.connection.features, f"supports_select_{self.query.combinator}"): + raise NotSupportedError( + f"{self.query.combinator} is not supported on this database backend." + ) + query.combinator_pipeline = self.get_combinator_queries() + else: + query.project_fields = self.get_project_fields(columns, ordering_fields) + # If columns is None, then get_project_fields() won't add + # ordering_fields to $project. Use $addFields (extra_fields) instead. + if columns is None: + extra_fields += ordering_fields + query.lookup_pipeline = self.get_lookup_pipeline() + where = self.get_where() + try: + expr = where.as_mql(self, self.connection) if where else {} + except FullResultSet: + query.mongo_query = {} + else: + query.mongo_query = {"$expr": expr} if extra_fields: query.extra_fields = self.get_project_fields(extra_fields, force_expression=True) - where = self.get_where() - try: - expr = where.as_mql(self, self.connection) if where else {} - except FullResultSet: - query.mongo_query = {} - else: - query.mongo_query = {"$expr": expr} return query def get_columns(self): @@ -391,6 +413,9 @@ def project_field(column): if hasattr(column, "target"): # column is a Col. target = column.target.column + # Handle Order By columns as refs columns. + elif isinstance(column, OrderBy) and isinstance(column.expression, Ref): + target = column.expression.refs else: # column is a Transform in values()/values_list() that needs a # name for $proj. @@ -412,6 +437,75 @@ def collection_name(self): def collection(self): return self.connection.get_collection(self.collection_name) + def get_combinator_queries(self): + parts = [] + compilers = [ + query.get_compiler(self.using, self.connection, self.elide_empty) + for query in self.query.combined_queries + ] + main_query_columns = self.get_columns() + main_query_fields, _ = zip(*main_query_columns, strict=True) + for compiler_ in compilers: + try: + # If the columns list is limited, then all combined queries + # must have the same columns list. Set the selects defined on + # the query on all combined queries, if not already set. + if not compiler_.query.values_select and self.query.values_select: + compiler_.query = compiler_.query.clone() + compiler_.query.set_values( + ( + *self.query.extra_select, + *self.query.values_select, + *self.query.annotation_select, + ) + ) + compiler_.pre_sql_setup() + columns = compiler_.get_columns() + parts.append((compiler_.build_query(columns), compiler_, columns)) + except EmptyResultSet: + # Omit the empty queryset with UNION. + if self.query.combinator == "union": + continue + raise + # Raise EmptyResultSet if all the combinator queries are empty. + if not parts: + raise EmptyResultSet + # Make the combinator's stages. + combinator_pipeline = None + for part, compiler_, columns in parts: + inner_pipeline = part.get_pipeline() + # Standardize result fields. + fields = {} + # When a .count() is called, the main_query_field has length 1 + # otherwise it has the same length as columns. + for alias, (ref, expr) in zip(main_query_fields, columns, strict=False): + if isinstance(expr, Col) and expr.alias != compiler_.collection_name: + fields[expr.alias] = 1 + else: + fields[alias] = f"${ref}" if alias != ref else 1 + inner_pipeline.append({"$project": fields}) + # Combine query with the current combinator pipeline. + if combinator_pipeline: + combinator_pipeline.append( + {"$unionWith": {"coll": compiler_.collection_name, "pipeline": inner_pipeline}} + ) + else: + combinator_pipeline = inner_pipeline + if not self.query.combinator_all: + ids = {} + for alias, expr in main_query_columns: + # Unfold foreign fields. + if isinstance(expr, Col) and expr.alias != self.collection_name: + ids[self._unfold_column(expr)] = expr.as_mql(self, self.connection) + else: + ids[alias] = f"${alias}" + combinator_pipeline.append({"$group": {"_id": ids}}) + projected_fields = self._fold_columns(ids) + combinator_pipeline.append({"$addFields": projected_fields}) + if "_id" not in projected_fields: + combinator_pipeline.append({"$unset": "_id"}) + return combinator_pipeline + def get_lookup_pipeline(self): result = [] for alias in tuple(self.query.alias_map): diff --git a/django_mongodb/features.py b/django_mongodb/features.py index a275cc3bd..441647901 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -22,9 +22,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_paramstyle_pyformat = False supports_select_difference = False supports_select_intersection = False - # Not implemented: https://github.com/mongodb-labs/django-mongodb/issues/72 - supports_select_union = False supports_sequence_reset = False + supports_slicing_ordering_in_compound = True supports_table_check_constraints = False supports_temporal_subtraction = True # MongoDB stores datetimes in UTC. @@ -234,6 +233,7 @@ def django_test_expected_failures(self): "Test assumes integer primary key.": { "db_functions.comparison.test_cast.CastTests.test_cast_to_integer_foreign_key", "model_fields.test_foreignkey.ForeignKeyTests.test_to_python", + "queries.test_qs_combinators.QuerySetSetOperationTests.test_order_raises_on_non_selected_column", }, "Exists is not supported on MongoDB.": { "aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_on_exists", @@ -267,6 +267,7 @@ def django_test_expected_failures(self): "model_forms.tests.LimitChoicesToTests.test_limit_choices_to_m2m_through", "model_forms.tests.LimitChoicesToTests.test_limit_choices_to_no_duplicates", "null_queries.tests.NullQueriesTests.test_reverse_relations", + "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_with_values_list_on_annotated_and_unannotated", "queries.tests.ExcludeTest17600.test_exclude_plain", "queries.tests.ExcludeTest17600.test_exclude_with_q_is_equal_to_plain_exclude_variation", "queries.tests.ExcludeTest17600.test_exclude_with_q_object_no_distinct", @@ -331,6 +332,8 @@ def django_test_expected_failures(self): "lookup.tests.LookupQueryingTests.test_filter_subquery_lhs", "model_fields.test_jsonfield.TestQuerying.test_nested_key_transform_on_subquery", "model_fields.test_jsonfield.TestQuerying.test_obj_subquery_lookup", + "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery", + "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery_related_outerref", }, "Using a QuerySet in annotate() is not supported on MongoDB.": { "aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_reused_subquery", @@ -368,6 +371,7 @@ def django_test_expected_failures(self): "model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery", "one_to_one.tests.OneToOneTests.test_get_prefetch_queryset_warning", "one_to_one.tests.OneToOneTests.test_rel_pk_subquery", + "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_with_ordering", "queries.tests.CloneTests.test_evaluated_queryset_as_argument", "queries.tests.DoubleInSubqueryTests.test_double_subquery_in", "queries.tests.EmptyQuerySetTests.test_values_subquery", @@ -468,6 +472,8 @@ def django_test_expected_failures(self): "ordering.tests.OrderingTests.test_extra_ordering", "ordering.tests.OrderingTests.test_extra_ordering_quoting", "ordering.tests.OrderingTests.test_extra_ordering_with_table_name", + "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_multiple_models_with_values_list_and_order_by_extra_select", + "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_with_extra_and_values_list", "queries.tests.EscapingTests.test_ticket_7302", "queries.tests.Queries5Tests.test_extra_select_literal_percent_s", "queries.tests.Queries5Tests.test_ticket7256", diff --git a/django_mongodb/query.py b/django_mongodb/query.py index a2ac7cc3e..7a7848cd0 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -51,6 +51,7 @@ def __init__(self, compiler): self.project_fields = None self.aggregation_pipeline = compiler.aggregation_pipeline self.extra_fields = None + self.combinator_pipeline = None def __repr__(self): return f"" @@ -78,6 +79,8 @@ def get_pipeline(self): pipeline.extend(self.aggregation_pipeline) if self.project_fields: pipeline.append({"$project": self.project_fields}) + if self.combinator_pipeline: + pipeline.extend(self.combinator_pipeline) if self.extra_fields: pipeline.append({"$addFields": self.extra_fields}) if self.ordering: