diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index d84790bf4..7081b2f0a 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -3,9 +3,9 @@ from collections import defaultdict from bson import SON -from django.core.exceptions import EmptyResultSet, FullResultSet -from django.db import DatabaseError, IntegrityError, NotSupportedError -from django.db.models import Count, Expression +from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet +from django.db import IntegrityError, NotSupportedError +from django.db.models import Count from django.db.models.aggregates import Aggregate, Variance from django.db.models.expressions import Case, Col, Ref, Value, When from django.db.models.functions.comparison import Coalesce @@ -581,7 +581,19 @@ def execute_sql(self, result_type): self.pre_sql_setup() values = [] for field, _, value in self.query.values: - if hasattr(value, "prepare_database_save"): + if hasattr(value, "resolve_expression"): + value = value.resolve_expression(self.query, allow_joins=False, for_save=True) + if value.contains_aggregate: + raise FieldError( + "Aggregate functions are not allowed in this query " + f"({field.name}={value})." + ) + if value.contains_over_clause: + raise FieldError( + "Window expressions are not allowed in this query " + f"({field.name}={value})." + ) + elif hasattr(value, "prepare_database_save"): if field.remote_field: value = value.prepare_database_save(field) else: @@ -591,9 +603,19 @@ def execute_sql(self, result_type): f"{field.__class__.__name__}." ) prepared = field.get_db_prep_save(value, connection=self.connection) - values.append((field, prepared)) + if hasattr(value, "as_mql"): + prepared = prepared.as_mql(self, self.connection) + values.append((field.column, prepared)) + try: + criteria = self.build_query().mongo_query + except EmptyResultSet: + return 0 is_empty = not bool(values) - rows = 0 if is_empty else self.update(values) + if is_empty: + rows = 0 + else: + rows = self.collection.update_many(criteria, [{"$set": dict(values)}]).matched_count + for query in self.query.get_related_updates(): aux_rows = query.get_compiler(self.using).execute_sql(result_type) if is_empty and aux_rows: @@ -601,32 +623,6 @@ def execute_sql(self, result_type): is_empty = False return rows - def update(self, values): - spec = {} - for field, value in values: - if field.primary_key: - raise DatabaseError("Cannot modify _id.") - if isinstance(value, Expression): - raise NotSupportedError("QuerySet.update() with expression not supported.") - # .update(foo=123) --> {'$set': {'foo': 123}} - spec.setdefault("$set", {})[field.column] = value - return self.execute_update(spec) - - @wrap_database_errors - def execute_update(self, update_spec): - try: - criteria = self.build_query().mongo_query - except EmptyResultSet: - return 0 - return self.collection.update_many(criteria, update_spec).matched_count - - def check_query(self): - super().check_query() - if len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1: - raise NotSupportedError( - "Cannot use QuerySet.update() when querying across multiple collections on MongoDB." - ) - def get_where(self): return self.query.where diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 41b64a5d9..bbf1375fd 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -105,68 +105,21 @@ def django_test_expected_failures(self): "m2m_through_regress.test_multitable.MultiTableTests.test_m2m_prefetch_reverse_proxied", }, "QuerySet.update() with expression not supported.": { - "annotations.tests.AliasTests.test_update_with_alias", - "annotations.tests.NonAggregateAnnotationTestCase.test_update_with_annotation", - "db_functions.comparison.test_least.LeastTests.test_update", - "db_functions.comparison.test_greatest.GreatestTests.test_update", - "db_functions.text.test_left.LeftTests.test_basic", - "db_functions.text.test_lower.LowerTests.test_basic", - "db_functions.text.test_replace.ReplaceTests.test_update", - "db_functions.text.test_substr.SubstrTests.test_basic", - "db_functions.text.test_upper.UpperTests.test_basic", - "expressions.tests.BasicExpressionsTests.test_arithmetic", - "expressions.tests.BasicExpressionsTests.test_filter_with_join", - "expressions.tests.BasicExpressionsTests.test_object_update", - "expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects", - "expressions.tests.BasicExpressionsTests.test_order_of_operations", - "expressions.tests.BasicExpressionsTests.test_parenthesis_priority", - "expressions.tests.BasicExpressionsTests.test_update", - "expressions.tests.BasicExpressionsTests.test_update_with_fk", - "expressions.tests.BasicExpressionsTests.test_update_with_none", - "expressions.tests.ExpressionsNumericTests.test_decimal_expression", - "expressions.tests.ExpressionsNumericTests.test_increment_value", - "expressions.tests.FTimeDeltaTests.test_delta_update", - "expressions.tests.FTimeDeltaTests.test_negative_timedelta_update", - "expressions.tests.ValueTests.test_update_TimeField_using_Value", "expressions.tests.ValueTests.test_update_UUIDField_using_Value", - "expressions_case.tests.CaseDocumentationExamples.test_conditional_update_example", - "expressions_case.tests.CaseExpressionTests.test_update", - "expressions_case.tests.CaseExpressionTests.test_update_big_integer", - "expressions_case.tests.CaseExpressionTests.test_update_binary", - "expressions_case.tests.CaseExpressionTests.test_update_boolean", - "expressions_case.tests.CaseExpressionTests.test_update_date", - "expressions_case.tests.CaseExpressionTests.test_update_date_time", - "expressions_case.tests.CaseExpressionTests.test_update_decimal", - "expressions_case.tests.CaseExpressionTests.test_update_duration", - "expressions_case.tests.CaseExpressionTests.test_update_email", - "expressions_case.tests.CaseExpressionTests.test_update_file", - "expressions_case.tests.CaseExpressionTests.test_update_file_path", - "expressions_case.tests.CaseExpressionTests.test_update_fk", - "expressions_case.tests.CaseExpressionTests.test_update_float", - "expressions_case.tests.CaseExpressionTests.test_update_generic_ip_address", - "expressions_case.tests.CaseExpressionTests.test_update_image", - "expressions_case.tests.CaseExpressionTests.test_update_null_boolean", - "expressions_case.tests.CaseExpressionTests.test_update_positive_big_integer", - "expressions_case.tests.CaseExpressionTests.test_update_positive_integer", - "expressions_case.tests.CaseExpressionTests.test_update_positive_small_integer", - "expressions_case.tests.CaseExpressionTests.test_update_slug", - "expressions_case.tests.CaseExpressionTests.test_update_small_integer", - "expressions_case.tests.CaseExpressionTests.test_update_string", - "expressions_case.tests.CaseExpressionTests.test_update_text", - "expressions_case.tests.CaseExpressionTests.test_update_time", - "expressions_case.tests.CaseExpressionTests.test_update_url", "expressions_case.tests.CaseExpressionTests.test_update_uuid", - "expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_condition", - "expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_value", - "expressions_case.tests.CaseExpressionTests.test_update_without_default", + "expressions.tests.FTimeDeltaTests.test_negative_timedelta_update", + "expressions.tests.BasicExpressionsTests.test_filter_with_join", + "expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects", "model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values", "queries.test_bulk_update.BulkUpdateNoteTests", - "queries.test_bulk_update.BulkUpdateTests", - "timezones.tests.NewDatabaseTests.test_update_with_timedelta", - "update.tests.AdvancedTests.test_update_annotated_queryset", - "update.tests.AdvancedTests.test_update_negated_f", - "update.tests.AdvancedTests.test_update_negated_f_conditional_annotation", - "update.tests.AdvancedTests.test_update_transformed_field", + "queries.test_bulk_update.BulkUpdateTests.test_database_routing", + "delete_regress.tests.SetQueryCountTests.test_set_querycount", + "delete.tests.FastDeleteTests.test_fast_delete_qs", + "delete.tests.FastDeleteTests.test_fast_delete_instance_set_pk_none", + "delete.tests.FastDeleteTests.test_fast_delete_large_batch", + "delete.tests.FastDeleteTests.test_fast_delete_fk", + "delete.tests.DeletionTests.test_cannot_defer_constraint_checks", + "basic.tests.SelectOnSaveTests.test_select_on_save", }, "AutoField not supported.": { "bulk_create.tests.BulkCreateTests.test_bulk_insert_nullable_fields",