diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index e64fe6899..ec2817965 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -67,6 +67,7 @@ jobs: - name: Run tests run: > python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2 + aggregation annotations auth_tests.test_models.UserManagerTestCase backends.base.test_base.DatabaseWrapperTests diff --git a/README.md b/README.md index 13000e8de..e00512ca4 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,6 @@ Migrations for 'admin': ## Known issues and limitations - The following `QuerySet` methods aren't supported: - - `aggregate()` - `bulk_update()` - `dates()` - `datetimes()` diff --git a/django_mongodb/__init__.py b/django_mongodb/__init__.py index a2dba98ee..7994999df 100644 --- a/django_mongodb/__init__.py +++ b/django_mongodb/__init__.py @@ -6,12 +6,14 @@ check_django_compatability() +from .aggregates import register_aggregates # noqa: E402 from .expressions import register_expressions # noqa: E402 from .fields import register_fields # noqa: E402 from .functions import register_functions # noqa: E402 from .lookups import register_lookups # noqa: E402 from .query import register_nodes # noqa: E402 +register_aggregates() register_expressions() register_fields() register_functions() diff --git a/django_mongodb/aggregates.py b/django_mongodb/aggregates.py new file mode 100644 index 000000000..1440d8aeb --- /dev/null +++ b/django_mongodb/aggregates.py @@ -0,0 +1,85 @@ +from copy import deepcopy + +from django.db.models.aggregates import Aggregate, Count, StdDev, Variance +from django.db.models.expressions import Case, Value, When +from django.db.models.lookups import Exact +from django.db.models.sql.where import WhereNode + +from .query_utils import process_lhs + +# Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower(). +MONGO_AGGREGATIONS = {Count: "sum"} + + +def aggregate( + self, + compiler, + connection, + operator=None, + resolve_inner_expression=False, + **extra_context, # noqa: ARG001 +): + if self.filter: + node = self.copy() + node.filter = None + source_expressions = node.get_source_expressions() + condition = When(self.filter, then=source_expressions[0]) + node.set_source_expressions([Case(condition)] + source_expressions[1:]) + else: + node = self + lhs_mql = process_lhs(node, compiler, connection) + if resolve_inner_expression: + return lhs_mql + operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower()) + return {f"${operator}": lhs_mql} + + +def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001 + """ + When resolve_inner_expression=True, return the MQL that resolves as a + value. This is used to count different elements, so the inner values are + returned to be pushed into a set. + """ + if not self.distinct or resolve_inner_expression: + if self.filter: + node = self.copy() + node.filter = None + source_expressions = node.get_source_expressions() + filter_ = deepcopy(self.filter) + filter_.add( + WhereNode([Exact(source_expressions[0], Value(None))], negated=True), + filter_.default, + ) + condition = When(filter_, then=Value(1)) + node.set_source_expressions([Case(condition)] + source_expressions[1:]) + inner_expression = process_lhs(node, compiler, connection) + else: + lhs_mql = process_lhs(self, compiler, connection) + null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]} + inner_expression = { + "$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1} + } + if resolve_inner_expression: + return inner_expression + return {"$sum": inner_expression} + # If distinct=True or resolve_inner_expression=False, sum the size of the + # set. + lhs_mql = process_lhs(self, compiler, connection) + # None shouldn't be counted, so subtract 1 if it's present. + exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}} + return {"$add": [{"$size": lhs_mql}, exits_null]} + + +def stddev_variance(self, compiler, connection, **extra_context): + if self.function.endswith("_SAMP"): + operator = "stdDevSamp" + elif self.function.endswith("_POP"): + operator = "stdDevPop" + return aggregate(self, compiler, connection, operator=operator, **extra_context) + + +def register_aggregates(): + Aggregate.as_mql = aggregate + Count.as_mql = count + StdDev.as_mql = stddev_variance + Variance.as_mql = stddev_variance diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 1d2398777..ba2753205 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -1,10 +1,13 @@ -from itertools import chain +import itertools +from collections import defaultdict from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import DatabaseError, IntegrityError, NotSupportedError from django.db.models import Count, Expression -from django.db.models.aggregates import Aggregate -from django.db.models.expressions import OrderBy +from django.db.models.aggregates import Aggregate, Variance +from django.db.models.expressions import Col, OrderBy, Value +from django.db.models.functions.comparison import Coalesce +from django.db.models.functions.math import Power from django.db.models.sql import compiler from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR, SINGLE from django.utils.functional import cached_property @@ -17,15 +20,203 @@ class SQLCompiler(compiler.SQLCompiler): """Base class for all Mongo compilers.""" query_class = MongoQuery + GROUP_SEPARATOR = "___" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.aggregation_pipeline = None + + def _get_group_alias_column(self, expr, annotation_group_idx): + """Generate a dummy field for use in the ids fields in $group.""" + replacement = None + if isinstance(expr, Col): + col = expr + else: + # If the column is a composite expression, create a field for it. + alias = f"__annotation_group{next(annotation_group_idx)}" + col = self._get_column_from_expression(expr, alias) + replacement = col + if self.collection_name == col.alias: + return col.target.column, replacement + # If this is a foreign field, replace the normal dot (.) with + # GROUP_SEPARATOR since FieldPath field names may not contain '.'. + return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}", replacement + + def _get_column_from_expression(self, expr, alias): + """ + Create a column named `alias` from the given expression to hold the + aggregate value. + """ + column_target = expr.output_field.__class__() + column_target.db_column = alias + column_target.set_attributes_from_name(alias) + return Col(self.collection_name, column_target) + + def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx): + """ + Prepare expressions for the aggregation pipeline. + + Handle the computation of aggregation functions used by various + expressions. Separate and create intermediate columns, and replace + nodes to simulate a group by operation. + + MongoDB's $group stage doesn't allow operations over the aggregator, + e.g. COALESCE(AVG(field), 3). However, it supports operations inside + the aggregation, e.g. AVG(number * 2). + + Handle the first case by splitting the computation into stages: compute + the aggregation first, then apply additional operations in a subsequent + stage by replacing the aggregate expressions with new columns prefixed + by `__aggregation`. + """ + replacements = {} + group = {} + for sub_expr in self._get_aggregate_expressions(expression): + alias = ( + f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target + ) + column_target = sub_expr.output_field.__class__() + column_target.db_column = alias + column_target.set_attributes_from_name(alias) + inner_column = Col(self.collection_name, column_target) + if sub_expr.distinct: + # If the expression should return distinct values, use + # $addToSet to deduplicate. + rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True) + group[alias] = {"$addToSet": rhs} + replacing_expr = sub_expr.copy() + replacing_expr.set_source_expressions([inner_column]) + else: + group[alias] = sub_expr.as_mql(self, self.connection) + replacing_expr = inner_column + # Count must return 0 rather than null. + if isinstance(sub_expr, Count): + replacing_expr = Coalesce(replacing_expr, 0) + # Variance = StdDev^2 + if isinstance(sub_expr, Variance): + replacing_expr = Power(replacing_expr, 2) + replacements[sub_expr] = replacing_expr + return replacements, group + + def _prepare_annotations_for_aggregation_pipeline(self): + """Prepare annotations for the aggregation pipeline.""" + replacements = {} + group = {} + annotation_group_idx = itertools.count(start=1) + for target, expr in self.query.annotation_select.items(): + if expr.contains_aggregate: + new_replacements, expr_group = self._prepare_expressions_for_pipeline( + expr, target, annotation_group_idx + ) + replacements.update(new_replacements) + group.update(expr_group) + having_replacements, having_group = self._prepare_expressions_for_pipeline( + self.having, None, annotation_group_idx + ) + replacements.update(having_replacements) + group.update(having_group) + return group, replacements + + def _get_group_id_expressions(self, order_by): + """Generate group ID expressions for the aggregation pipeline.""" + group_expressions = set() + replacements = {} + for expr, (_, _, is_ref) in order_by: + if not is_ref: + group_expressions |= set(expr.get_group_by_cols()) + for expr, *_ in self.select: + group_expressions |= set(expr.get_group_by_cols()) + having_group_by = self.having.get_group_by_cols() if self.having else () + for expr in having_group_by: + group_expressions.add(expr) + if isinstance(self.query.group_by, tuple | list): + group_expressions |= set(self.query.group_by) + elif self.query.group_by is None: + group_expressions = set() + if not group_expressions: + ids = None + else: + annotation_group_idx = itertools.count(start=1) + ids = {} + for col in group_expressions: + alias, replacement = self._get_group_alias_column(col, annotation_group_idx) + try: + ids[alias] = col.as_mql(self, self.connection) + except EmptyResultSet: + ids[alias] = Value(False).as_mql(self, self.connection) + except FullResultSet: + ids[alias] = Value(True).as_mql(self, self.connection) + if replacement is not None: + replacements[col] = replacement + return ids, replacements + + def _build_aggregation_pipeline(self, ids, group): + """Build the aggregation pipeline for grouping.""" + pipeline = [] + if not ids: + group["_id"] = None + pipeline.append({"$facet": {"group": [{"$group": group}]}}) + pipeline.append( + { + "$addFields": { + key: { + "$getField": { + "input": {"$arrayElemAt": ["$group", 0]}, + "field": key, + } + } + for key in group + } + } + ) + else: + group["_id"] = ids + pipeline.append({"$group": group}) + projected_fields = defaultdict(dict) + for key in ids: + value = f"$_id.{key}" + if self.GROUP_SEPARATOR in key: + table, field = key.split(self.GROUP_SEPARATOR) + projected_fields[table][field] = value + else: + projected_fields[key] = value + pipeline.append({"$addFields": projected_fields}) + if "_id" not in projected_fields: + pipeline.append({"$unset": "_id"}) + return pipeline + + def pre_sql_setup(self, with_col_aliases=False): + extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases) + group, all_replacements = self._prepare_annotations_for_aggregation_pipeline() + # query.group_by is either: + # - None: no GROUP BY + # - True: group by select fields + # - a list of expressions to group by. + if group or self.query.group_by: + ids, replacements = self._get_group_id_expressions(order_by) + all_replacements.update(replacements) + pipeline = self._build_aggregation_pipeline(ids, group) + if self.having: + pipeline.append( + { + "$match": { + "$expr": self.having.replace_expressions(all_replacements).as_mql( + self, self.connection + ) + } + } + ) + self.aggregation_pipeline = pipeline + self.annotations = { + target: expr.replace_expressions(all_replacements) + for target, expr in self.query.annotation_select.items() + } + return extra_select, order_by, group_by def execute_sql( self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE ): self.pre_sql_setup() - # QuerySet.count() - if self.query.annotations == {"__count": Count("*")}: - return [self.get_count()] - columns = self.get_columns() try: query = self.build_query( @@ -77,16 +268,13 @@ def results_iter( fields = [s[0] for s in self.select[0 : self.col_count]] converters = self.get_converters(fields) - rows = chain.from_iterable(results) + rows = itertools.chain.from_iterable(results) if converters: rows = self.apply_converters(rows, converters) if tuple_expected: rows = map(tuple, rows) return rows - def has_results(self): - return bool(self.get_count(check_exists=True)) - def _make_result(self, entity, columns): """ Decode values for the given fields from the database entity. @@ -139,44 +327,21 @@ def check_query(self): if any(key.startswith("_prefetch_related_") for key in self.query.extra): raise NotSupportedError("QuerySet.prefetch_related() is not supported on MongoDB.") raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.") - if any( - isinstance(a, Aggregate) and not isinstance(a, Count) - for a in self.query.annotations.values() - ): - raise NotSupportedError("QuerySet.aggregate() isn't supported on MongoDB.") - - def get_count(self, check_exists=False): - """ - Count objects matching the current filters / constraints. - - If `check_exists` is True, only check if any object matches. - """ - kwargs = {} - # If this query is sliced, the limits will be set on the subquery. - inner_query = getattr(self.query, "inner_query", None) - low_mark = inner_query.low_mark if inner_query else 0 - high_mark = inner_query.high_mark if inner_query else None - if low_mark > 0: - kwargs["skip"] = low_mark - if check_exists: - kwargs["limit"] = 1 - elif high_mark is not None: - kwargs["limit"] = high_mark - low_mark - try: - return self.build_query().count(**kwargs) - except EmptyResultSet: - return 0 def build_query(self, columns=None): """Check if the query is supported and prepare a MongoQuery.""" self.check_query() - query = self.query_class(self, columns) + query = self.query_class(self) query.lookup_pipeline = self.get_lookup_pipeline() + query.order_by(self._get_ordering()) + query.project_fields = self.get_project_fields(columns, ordering=query.ordering) + where = self.get_where() try: - query.mongo_query = {"$expr": self.query.where.as_mql(self, self.connection)} + expr = where.as_mql(self, self.connection) if where else {} except FullResultSet: query.mongo_query = {} - query.order_by(self._get_ordering()) + else: + query.mongo_query = {"$expr": expr} return query def get_columns(self): @@ -211,7 +376,7 @@ def project_field(column): return ( tuple(map(project_field, columns)) - + tuple(self.query.annotation_select.items()) + + tuple(self.annotations.items()) + tuple(map(project_field, related_columns)) ) @@ -276,6 +441,47 @@ def get_lookup_pipeline(self): result += self.query.alias_map[alias].as_mql(self, self.connection) return result + def _get_aggregate_expressions(self, expr): + stack = [expr] + while stack: + expr = stack.pop() + if isinstance(expr, Aggregate): + yield expr + elif hasattr(expr, "get_source_expressions"): + stack.extend(expr.get_source_expressions()) + + def get_project_fields(self, columns=None, ordering=None): + fields = {} + for name, expr in columns or []: + try: + column = expr.target.column + except AttributeError: + # Generate the MQL for an annotation. + try: + fields[name] = expr.as_mql(self, self.connection) + except EmptyResultSet: + fields[name] = Value(False).as_mql(self, self.connection) + except FullResultSet: + fields[name] = Value(True).as_mql(self, self.connection) + else: + # If name != column, then this is an annotatation referencing + # another column. + fields[name] = 1 if name == column else f"${column}" + if fields: + # Add related fields. + for alias in self.query.alias_map: + if self.query.alias_refcount[alias] and self.collection_name != alias: + fields[alias] = 1 + # Add order_by() fields. + for column, _ in ordering or []: + foreign_table = column.split(".", 1)[0] if "." in column else None + if column not in fields and foreign_table not in fields: + fields[column] = 1 + return fields + + def get_where(self): + return self.where + class SQLInsertCompiler(SQLCompiler): def execute_sql(self, returning_fields=None): @@ -311,7 +517,7 @@ def insert(self, docs, returning_fields=None): class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): def execute_sql(self, result_type=MULTI): cursor = Cursor() - cursor.rowcount = self.build_query([self.query.get_meta().pk]).delete() + cursor.rowcount = self.build_query().delete() return cursor def check_query(self): @@ -321,6 +527,9 @@ def check_query(self): "Cannot use QuerySet.delete() when querying across multiple collections on MongoDB." ) + def get_where(self): + return self.query.where + class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): def execute_sql(self, result_type): @@ -382,6 +591,28 @@ def check_query(self): "Cannot use QuerySet.update() when querying across multiple collections on MongoDB." ) + def get_where(self): + return self.query.where + class SQLAggregateCompiler(SQLCompiler): - pass + def build_query(self, columns=None): + query = self.query_class(self) + query.project_fields = self.get_project_fields(tuple(self.annotations.items())) + compiler = self.query.inner_query.get_compiler( + self.using, + elide_empty=self.elide_empty, + ) + compiler.pre_sql_setup(with_col_aliases=False) + # Avoid $project (columns=None) if unneeded. + columns = ( + compiler.get_columns() + if compiler.query.annotations or not compiler.query.default_cols + else None + ) + subquery = compiler.build_query(columns) + query.subquery = subquery + return query + + def _make_result(self, result, columns=None): + return [result[k] for k in self.query.annotation_select] diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index fc477bed1..a9c5ab9bc 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -13,6 +13,7 @@ NegatedExpression, Ref, ResolvedOuterRef, + Star, Subquery, Value, When, @@ -76,7 +77,11 @@ def query(self, compiler, connection): # noqa: ARG001 def ref(self, compiler, connection): # noqa: ARG001 - return self.refs + return f"${self.refs}" + + +def star(self, compiler, connection): # noqa: ARG001 + return {"$literal": True} def subquery(self, compiler, connection): # noqa: ARG001 @@ -113,6 +118,7 @@ def register_expressions(): Query.as_mql = query Ref.as_mql = ref ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql + Star.as_mql = star Subquery.as_mql = subquery When.as_mql = when Value.as_mql = value diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 4eaff0c23..1af277e3c 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -34,10 +34,21 @@ class DatabaseFeatures(BaseDatabaseFeatures): "basic.tests.SelectOnSaveTests.test_select_on_save_lying_update", # Order by constant not supported: # AttributeError: 'Field' object has no attribute 'model' + "aggregation.tests.AggregateTestCase.test_annotate_values_list", + "aggregation.tests.AggregateTestCase.test_grouped_annotation_in_group_by", + "aggregation.tests.AggregateTestCase.test_non_grouped_annotation_not_in_group_by", + "aggregation.tests.AggregateTestCase.test_values_annotation_with_expression", + "annotations.tests.NonAggregateAnnotationTestCase.test_order_by_aggregate", + "model_fields.test_jsonfield.TestQuerying.test_ordering_grouping_by_count", + "ordering.tests.OrderingTests.test_default_ordering_does_not_affect_group_by", "ordering.tests.OrderingTests.test_order_by_constant_value", "expressions.tests.NegatedExpressionTests.test_filter", "expressions_case.tests.CaseExpressionTests.test_order_by_conditional_implicit", + # BaseExpression.convert_value() crashes with Decimal128. + "aggregation.tests.AggregateTestCase.test_combine_different_types", + "annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation", # NotSupportedError: order_by() expression not supported. + "aggregation.tests.AggregateTestCase.test_aggregation_order_by_not_selected_annotation_values", "db_functions.comparison.test_coalesce.CoalesceTests.test_ordering", "db_functions.tests.FunctionTests.test_nested_function_ordering", "db_functions.text.test_length.LengthTests.test_ordering", @@ -82,7 +93,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "ordering.tests.OrderingTests.test_order_by_parent_fk_with_expression_in_default_ordering", "ordering.tests.OrderingTests.test_order_by_ptr_field_with_default_ordering_by_expression", "queries.tests.Queries1Tests.test_order_by_tables", - "queries.tests.Queries1Tests.test_ticket4358", "queries.tests.TestTicket24605.test_ticket_24605", "queries.tests.TestInvalidValuesRelation.test_invalid_values", # alias().order_by() doesn't work. @@ -90,11 +100,17 @@ class DatabaseFeatures(BaseDatabaseFeatures): "annotations.tests.AliasTests.test_order_by_alias_aggregate", # annotate() + values_list() + order_by() loses annotated value. "expressions_case.tests.CaseExpressionTests.test_annotate_values_not_in_order_by", - # pymongo.errors.OperationFailure: the limit must be positive - "queries.tests.WeirdQuerysetSlicingTests.test_tickets_7698_10202", # QuerySet.explain() not implemented: # https://github.com/mongodb-labs/django-mongodb/issues/28 "queries.test_explain.ExplainUnsupportedTests.test_message", + # The $sum aggregation returns 0 instead of None for null. + "aggregation.test_filter_argument.FilteredAggregateTests.test_plain_annotate", + "aggregation.tests.AggregateTestCase.test_aggregation_default_passed_another_aggregate", + "aggregation.tests.AggregateTestCase.test_annotation_expressions", + "aggregation.tests.AggregateTestCase.test_reverse_fkey_annotate", + # Incorrect order: pipeline does not order by the correct fields. + "aggregation.tests.AggregateTestCase.test_annotate_ordering", + "aggregation.tests.AggregateTestCase.test_even_more_aggregate", } # $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3. _django_test_expected_failures_bitwise = { @@ -234,42 +250,14 @@ def django_test_expected_failures(self): "db_functions.comparison.test_cast.CastTests.test_cast_to_integer_foreign_key", "model_fields.test_foreignkey.ForeignKeyTests.test_to_python", }, - # https://github.com/mongodb-labs/django-mongodb/issues/12 - "QuerySet.aggregate() not supported.": { - "annotations.tests.AliasTests.test_alias_default_alias_expression", - "annotations.tests.AliasTests.test_filter_alias_agg_with_double_f", - "annotations.tests.NonAggregateAnnotationTestCase.test_aggregate_over_annotation", - "annotations.tests.NonAggregateAnnotationTestCase.test_aggregate_over_full_expression_annotation", - "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_exists_aggregate_values_chaining", - "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_in_f_grouped_by_annotation", - "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_subquery_and_aggregate_values_chaining", - "annotations.tests.NonAggregateAnnotationTestCase.test_filter_agg_with_double_f", - "annotations.tests.NonAggregateAnnotationTestCase.test_values_with_pk_annotation", - "expressions.test_queryset_values.ValuesExpressionsTests.test_chained_values_with_expression", - "expressions.test_queryset_values.ValuesExpressionsTests.test_values_expression_group_by", - "expressions.tests.BasicExpressionsTests.test_annotate_values_aggregate", - "expressions_case.tests.CaseExpressionTests.test_aggregate", - "expressions_case.tests.CaseExpressionTests.test_aggregate_with_expression_as_condition", - "expressions_case.tests.CaseExpressionTests.test_aggregate_with_expression_as_value", - "expressions_case.tests.CaseExpressionTests.test_aggregation_empty_cases", - "expressions_case.tests.CaseExpressionTests.test_annotate_with_aggregation_in_condition", - "expressions_case.tests.CaseExpressionTests.test_annotate_with_aggregation_in_predicate", - "expressions_case.tests.CaseExpressionTests.test_annotate_with_aggregation_in_value", - "expressions_case.tests.CaseExpressionTests.test_annotate_with_in_clause", - "expressions_case.tests.CaseExpressionTests.test_filter_with_aggregation_in_condition", - "expressions_case.tests.CaseExpressionTests.test_filter_with_aggregation_in_predicate", - "expressions_case.tests.CaseExpressionTests.test_filter_with_aggregation_in_value", - "expressions_case.tests.CaseExpressionTests.test_m2m_exclude", - "expressions_case.tests.CaseExpressionTests.test_m2m_reuse", - "lookup.test_decimalfield.DecimalFieldLookupTests", - "lookup.tests.LookupQueryingTests.test_aggregate_combined_lookup", - "from_db_value.tests.FromDBValueTest.test_aggregation", - "timezones.tests.LegacyDatabaseTests.test_query_aggregation", - "timezones.tests.LegacyDatabaseTests.test_query_annotation", - "timezones.tests.NewDatabaseTests.test_query_aggregation", - "timezones.tests.NewDatabaseTests.test_query_annotation", - }, "Exists is not supported on MongoDB.": { + "aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_on_exists", + "aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_ref_multiple_subquery_annotation", + "aggregation.tests.AggregateTestCase.test_aggregation_exists_multivalued_outeref", + "aggregation.tests.AggregateTestCase.test_group_by_exists_annotation", + "aggregation.tests.AggregateTestCase.test_exists_none_with_aggregate", + "aggregation.tests.AggregateTestCase.test_exists_extra_where_with_aggregate", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_exists_aggregate_values_chaining", "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_exists_none_query", "delete_regress.tests.DeleteTests.test_self_reference_with_through_m2m_at_second_level", "expressions.tests.BasicExpressionsTests.test_annotation_with_deeply_nested_outerref", @@ -314,11 +302,22 @@ def django_test_expected_failures(self): "queries.tests.Ticket22429Tests.test_ticket_22429", }, "Subquery is not supported on MongoDB.": { + "aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_ref_subquery_annotation", + "aggregation.tests.AggregateAnnotationPruningTests.test_referenced_composed_subquery_requires_wrapping", + "aggregation.tests.AggregateAnnotationPruningTests.test_referenced_subquery_requires_wrapping", + "aggregation.tests.AggregateTestCase.test_aggregation_nested_subquery_outerref", + "aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation", + "aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_multivalued", + "aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_related_field", + "aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_values", + "aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_values_collision", "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_filter_with_subquery", + "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_subquery_and_aggregate_values_chaining", "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_subquery_outerref_transform", "annotations.tests.NonAggregateAnnotationTestCase.test_empty_queryset_annotation", "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_extract_outerref", "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_subquery_with_parameters", + "expressions.tests.BasicExpressionsTests.test_aggregate_subquery_annotation", "expressions.tests.BasicExpressionsTests.test_annotation_with_nested_outerref", "expressions.tests.BasicExpressionsTests.test_annotation_with_outerref", "expressions.tests.BasicExpressionsTests.test_annotations_within_subquery", @@ -341,10 +340,15 @@ def django_test_expected_failures(self): "model_fields.test_jsonfield.TestQuerying.test_obj_subquery_lookup", }, "Using a QuerySet in annotate() is not supported on MongoDB.": { + "aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_reused_subquery", + "aggregation.tests.AggregateTestCase.test_filter_in_subquery_or_aggregation", + "aggregation.tests.AggregateTestCase.test_group_by_subquery_annotation", + "aggregation.tests.AggregateTestCase.test_group_by_reference_subquery", "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_and_alias_filter_in_subquery", "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_and_alias_filter_related_in_subquery", "annotations.tests.NonAggregateAnnotationTestCase.test_empty_expression_annotation", "db_functions.comparison.test_coalesce.CoalesceTests.test_empty_queryset", + "expressions_case.tests.CaseExpressionTests.test_annotate_with_in_clause", "expressions.tests.FTimeDeltaTests.test_date_subquery_subtraction", "expressions.tests.FTimeDeltaTests.test_datetime_subquery_subtraction", "expressions.tests.FTimeDeltaTests.test_time_subquery_subtraction", @@ -377,30 +381,6 @@ def django_test_expected_failures(self): "queries.tests.WeirdQuerysetSlicingTests.test_empty_sliced_subquery", "queries.tests.WeirdQuerysetSlicingTests.test_empty_sliced_subquery_exclude", }, - # Invalid $project :: caused by :: Unknown expression $count - # https://github.com/mongodb-labs/django-mongodb/issues/79 - "Count() in QuerySet.annotate() crashes.": { - "annotations.tests.AliasTests.test_alias_annotate_with_aggregation", - "annotations.tests.NonAggregateAnnotationTestCase.test_annotate_exists", - "annotations.tests.NonAggregateAnnotationTestCase.test_annotate_with_aggregation", - "annotations.tests.NonAggregateAnnotationTestCase.test_combined_expression_annotation_with_aggregation", - "annotations.tests.NonAggregateAnnotationTestCase.test_combined_f_expression_annotation_with_aggregation", - "annotations.tests.NonAggregateAnnotationTestCase.test_full_expression_annotation_with_aggregation", - "annotations.tests.NonAggregateAnnotationTestCase.test_grouping_by_q_expression_annotation", - "annotations.tests.NonAggregateAnnotationTestCase.test_order_by_aggregate", - "annotations.tests.NonAggregateAnnotationTestCase.test_q_expression_annotation_with_aggregation", - "db_functions.comparison.test_cast.CastTests.test_cast_from_db_datetime_to_date_group_by", - "defer_regress.tests.DeferRegressionTest.test_basic", - "defer_regress.tests.DeferRegressionTest.test_defer_annotate_select_related", - "defer_regress.tests.DeferRegressionTest.test_ticket_16409", - "expressions.tests.BasicExpressionsTests.test_aggregate_subquery_annotation", - "expressions.tests.FieldTransformTests.test_month_aggregation", - "expressions_case.tests.CaseDocumentationExamples.test_conditional_aggregation_example", - "model_fields.test_jsonfield.TestQuerying.test_ordering_grouping_by_count", - "ordering.tests.OrderingTests.test_default_ordering_does_not_affect_group_by", - "queries.tests.Queries1Tests.test_ticket_20250", - "queries.tests.ValuesQuerysetTests.test_named_values_list_expression_with_default_alias", - }, "Cannot use QuerySet.delete() when querying across multiple collections on MongoDB.": { "delete.tests.FastDeleteTests.test_fast_delete_aggregation", "delete.tests.FastDeleteTests.test_fast_delete_empty_no_update_can_self_select", @@ -417,6 +397,7 @@ def django_test_expected_failures(self): "queries.tests.Queries5Tests.test_ticket9848", }, "QuerySet.dates() is not supported on MongoDB.": { + "aggregation.tests.AggregateTestCase.test_dates_with_aggregation", "annotations.tests.AliasTests.test_dates_alias", "dates.tests.DatesTests.test_dates_trunc_datetime_fields", "dates.tests.DatesTests.test_related_model_traverse", @@ -434,6 +415,7 @@ def django_test_expected_failures(self): "timezones.tests.NewDatabaseTests.test_query_datetimes_in_other_timezone", }, "QuerySet.distinct() is not supported.": { + "aggregation.tests.AggregateTestCase.test_sum_distinct_aggregate", "lookup.tests.LookupTests.test_lookup_collision_distinct", "queries.tests.ExcludeTest17600.test_exclude_plain_distinct", "queries.tests.ExcludeTest17600.test_exclude_with_q_is_equal_to_plain_exclude", @@ -487,6 +469,11 @@ def django_test_expected_failures(self): "update.tests.AdvancedTests.test_update_annotated_multi_table_queryset", }, "Test inspects query for SQL": { + "aggregation.tests.AggregateAnnotationPruningTests.test_non_aggregate_annotation_pruned", + "aggregation.tests.AggregateAnnotationPruningTests.test_unreferenced_aggregate_annotation_pruned", + "aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_pruned", + "aggregation.tests.AggregateAnnotationPruningTests.test_referenced_aggregate_annotation_kept", + "aggregation.tests.AggregateTestCase.test_count_star", "delete.tests.DeletionTests.test_only_referenced_fields_selected", "lookup.tests.LookupTests.test_in_ignore_none", "lookup.tests.LookupTests.test_textfield_exact_null", @@ -494,6 +481,7 @@ def django_test_expected_failures(self): "queries.tests.Queries6Tests.test_col_alias_quoted", }, "Test executes raw SQL.": { + "aggregation.tests.AggregateTestCase.test_coalesced_empty_result_set", "annotations.tests.NonAggregateAnnotationTestCase.test_raw_sql_with_inherited_field", "delete_regress.tests.DeleteLockingTest.test_concurrent_delete", "expressions.tests.BasicExpressionsTests.test_annotate_values_filter", @@ -511,7 +499,10 @@ def django_test_expected_failures(self): "timezones.tests.NewDatabaseTests.test_cursor_explicit_time_zone", "timezones.tests.NewDatabaseTests.test_raw_sql", }, - "Custom functions with SQL don't work on MongoDB.": { + "Custom aggregations/functions with SQL don't work on MongoDB.": { + "aggregation.tests.AggregateTestCase.test_add_implementation", + "aggregation.tests.AggregateTestCase.test_multi_arg_aggregate", + "aggregation.tests.AggregateTestCase.test_empty_result_optimization", "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions", "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions", }, @@ -540,6 +531,7 @@ def django_test_expected_failures(self): "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_extract_quarter_func_boundaries", }, "TruncDate database function not supported.": { + "aggregation.tests.AggregateTestCase.test_aggregation_default_using_date_from_database", "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_date_func", "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_date_none", "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_lookup_name_sql_injection", @@ -554,6 +546,7 @@ def django_test_expected_failures(self): "db_functions.datetime.test_extract_trunc.DateFunctionTests.test_trunc_time_none", }, "MongoDB can't annotate ($project) a function like PI().": { + "aggregation.tests.AggregateTestCase.test_aggregation_default_using_decimal_from_database", "db_functions.math.test_pi.PiTests.test", }, "Can't cast from date to datetime without MongoDB interpreting the new value in UTC.": { @@ -585,6 +578,7 @@ def django_test_expected_failures(self): "model_fields.test_jsonfield.TestQuerying.test_none_key_exclude", }, "Randomized ordering isn't supported by MongoDB.": { + "aggregation.tests.AggregateTestCase.test_aggregation_random_ordering", "ordering.tests.OrderingTests.test_random_ordering", }, "Queries without a collection aren't supported on MongoDB.": { diff --git a/django_mongodb/functions.py b/django_mongodb/functions.py index 0753577f5..071f7756a 100644 --- a/django_mongodb/functions.py +++ b/django_mongodb/functions.py @@ -13,6 +13,7 @@ ExtractWeek, ExtractWeekDay, ExtractYear, + Now, TruncBase, ) from django.db.models.functions.math import Ceil, Cot, Degrees, Log, Power, Radians, Random, Round @@ -120,6 +121,10 @@ def log(self, compiler, connection): return func(clone, compiler, connection) +def now(self, compiler, connection): # noqa: ARG001 + return "$$NOW" + + def null_if(self, compiler, connection): """Return None if expr1==expr2 else expr1.""" expr1, expr2 = (expr.as_mql(compiler, connection) for expr in self.get_source_expressions()) @@ -198,6 +203,7 @@ def register_functions(): Log.as_mql = log Lower.as_mql = perserve_null("toLower") LTrim.as_mql = trim("ltrim") + Now.as_mql = now NullIf.as_mql = null_if Replace.as_mql = replace Round.as_mql = round_ diff --git a/django_mongodb/query.py b/django_mongodb/query.py index c5b4ad7c6..38821cdc6 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -3,7 +3,7 @@ from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import DatabaseError, IntegrityError -from django.db.models.expressions import Case, Value, When +from django.db.models.expressions import Case, When from django.db.models.functions import Mod from django.db.models.lookups import Exact from django.db.models.sql.constants import INNER @@ -37,31 +37,24 @@ class MongoQuery: built by Django to a "representation" more suitable for MongoDB. """ - def __init__(self, compiler, columns): + def __init__(self, compiler): self.compiler = compiler self.connection = compiler.connection self.ops = compiler.connection.ops self.query = compiler.query - self.columns = columns self._negated = False self.ordering = [] self.collection = self.compiler.get_collection() self.collection_name = self.compiler.collection_name self.mongo_query = getattr(compiler.query, "raw_query", {}) + self.subquery = None self.lookup_pipeline = None + self.project_fields = None + self.aggregation_pipeline = compiler.aggregation_pipeline def __repr__(self): return f"" - @wrap_database_errors - def count(self, limit=None, skip=None): - """ - Return the number of objects that would be returned, if this query was - executed, up to `limit`, skipping `skip`. - """ - result = list(self.get_cursor(count=True, limit=limit, skip=skip)) - return result[0]["__count"] if result else 0 - def order_by(self, ordering): """ Reorder query results or execution order. Called by compiler during @@ -87,58 +80,30 @@ def delete(self): return self.collection.delete_many(self.mongo_query, **options).deleted_count @wrap_database_errors - def get_cursor(self, count=False, limit=None, skip=None): + def get_cursor(self): """ Return a pymongo CommandCursor that can be iterated on to give the results of the query. - - If `count` is True, return a single document with the number of - documents that match the query. - - Use `limit` or `skip` to override those options of the query. """ - fields = {} - for name, expr in self.columns or []: - try: - column = expr.target.column - except AttributeError: - # Generate the MQL for an annotation. - try: - fields[name] = expr.as_mql(self.compiler, self.connection) - except EmptyResultSet: - fields[name] = Value(False).as_mql(self.compiler, self.connection) - except FullResultSet: - fields[name] = Value(True).as_mql(self.compiler, self.connection) - else: - # If name != column, then this is an annotatation referencing - # another column. - fields[name] = 1 if name == column else f"${column}" - if fields: - # Add related fields. - for alias in self.query.alias_map: - if self.query.alias_refcount[alias] and self.collection_name != alias: - fields[alias] = 1 - # Construct the query pipeline. - pipeline = [] + return self.collection.aggregate(self.get_pipeline()) + + def get_pipeline(self): + pipeline = self.subquery.get_pipeline() if self.subquery else [] if self.lookup_pipeline: pipeline.extend(self.lookup_pipeline) if self.mongo_query: pipeline.append({"$match": self.mongo_query}) - if fields: - pipeline.append({"$project": fields}) + if self.aggregation_pipeline: + pipeline.extend(self.aggregation_pipeline) + if self.project_fields: + pipeline.append({"$project": self.project_fields}) if self.ordering: pipeline.append({"$sort": dict(self.ordering)}) - if skip is not None: - pipeline.append({"$skip": skip}) - elif self.query.low_mark > 0: + if self.query.low_mark > 0: pipeline.append({"$skip": self.query.low_mark}) - if limit is not None: - pipeline.append({"$limit": limit}) - elif self.query.high_mark is not None: + if self.query.high_mark is not None: pipeline.append({"$limit": self.query.high_mark - self.query.low_mark}) - if count: - pipeline.append({"$group": {"_id": None, "__count": {"$sum": 1}}}) - return self.collection.aggregate(pipeline) + return pipeline def join(self, compiler, connection): diff --git a/django_mongodb/query_utils.py b/django_mongodb/query_utils.py index 4b616d908..fe3fed9b8 100644 --- a/django_mongodb/query_utils.py +++ b/django_mongodb/query_utils.py @@ -1,4 +1,5 @@ from django.core.exceptions import FullResultSet +from django.db.models.aggregates import Aggregate from django.db.models.expressions import Value @@ -15,6 +16,8 @@ def process_lhs(node, compiler, connection): result.append(expr.as_mql(compiler, connection)) except FullResultSet: result.append(Value(True).as_mql(compiler, connection)) + if isinstance(node, Aggregate): + return result[0] return result # node is a Transform with just one source expression, aliased as "lhs". if is_direct_value(node.lhs):