diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 3a6efbe3e..38ed109e1 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -3,9 +3,9 @@ from collections import defaultdict from bson import SON -from django.core.exceptions import EmptyResultSet, FullResultSet -from django.db import DatabaseError, IntegrityError, NotSupportedError -from django.db.models import Count, Expression +from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet +from django.db import IntegrityError, NotSupportedError +from django.db.models import Count from django.db.models.aggregates import Aggregate, Variance from django.db.models.expressions import Case, Col, Ref, Value, When from django.db.models.functions.comparison import Coalesce @@ -576,9 +576,21 @@ def execute_sql(self, result_type): related queries are not available. """ self.pre_sql_setup() - values = [] + values = {} for field, _, value in self.query.values: - if hasattr(value, "prepare_database_save"): + if hasattr(value, "resolve_expression"): + value = value.resolve_expression(self.query, allow_joins=False, for_save=True) + if value.contains_aggregate: + raise FieldError( + "Aggregate functions are not allowed in this query " + f"({field.name}={value})." + ) + if value.contains_over_clause: + raise FieldError( + "Window expressions are not allowed in this query " + f"({field.name}={value})." + ) + elif hasattr(value, "prepare_database_save"): if field.remote_field: value = value.prepare_database_save(field) else: @@ -588,9 +600,15 @@ def execute_sql(self, result_type): f"{field.__class__.__name__}." ) prepared = field.get_db_prep_save(value, connection=self.connection) - values.append((field, prepared)) + if hasattr(value, "as_mql"): + prepared = prepared.as_mql(self, self.connection) + values[field.column] = prepared + try: + criteria = self.build_query().mongo_query + except EmptyResultSet: + return 0 is_empty = not bool(values) - rows = 0 if is_empty else self.update(values) + rows = 0 if is_empty else self.update(criteria, [{"$set": values}]) for query in self.query.get_related_updates(): aux_rows = query.get_compiler(self.using).execute_sql(result_type) if is_empty and aux_rows: @@ -598,24 +616,9 @@ def execute_sql(self, result_type): is_empty = False return rows - def update(self, values): - spec = {} - for field, value in values: - if field.primary_key: - raise DatabaseError("Cannot modify _id.") - if isinstance(value, Expression): - raise NotSupportedError("QuerySet.update() with expression not supported.") - # .update(foo=123) --> {'$set': {'foo': 123}} - spec.setdefault("$set", {})[field.column] = value - return self.execute_update(spec) - @wrap_database_errors - def execute_update(self, update_spec): - try: - criteria = self.build_query().mongo_query - except EmptyResultSet: - return 0 - return self.collection.update_many(criteria, update_spec).matched_count + def update(self, criteria, pipeline): + return self.collection.update_many(criteria, pipeline).matched_count def check_query(self): super().check_query() diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index 9012279ea..3b2eea2d4 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -1,5 +1,6 @@ import datetime from decimal import Decimal +from uuid import UUID from bson import Decimal128 from django.core.exceptions import EmptyResultSet, FullResultSet @@ -110,6 +111,8 @@ def value(self, compiler, connection): # noqa: ARG001 elif isinstance(value, datetime.timedelta): # DurationField stores milliseconds rather than microseconds. value /= datetime.timedelta(milliseconds=1) + elif isinstance(value, UUID): + value = value.hex return {"$literal": value} diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 145c3049b..2aa8217dc 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -99,6 +99,7 @@ def django_test_expected_failures(self): "expressions.tests.BasicExpressionsTests.test_new_object_save", "expressions.tests.BasicExpressionsTests.test_object_create_with_aggregate", "expressions.tests.BasicExpressionsTests.test_object_create_with_f_expression_in_subquery", + "expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects", # PI() "db_functions.math.test_round.RoundTests.test_decimal_with_precision", "db_functions.math.test_round.RoundTests.test_float_with_precision", @@ -124,70 +125,6 @@ def django_test_expected_failures(self): "many_to_many.tests.ManyToManyTests.test_set_after_prefetch", "model_forms.tests.OtherModelFormTests.test_prefetch_related_queryset", }, - "QuerySet.update() with expression not supported.": { - "annotations.tests.AliasTests.test_update_with_alias", - "annotations.tests.NonAggregateAnnotationTestCase.test_update_with_annotation", - "db_functions.comparison.test_least.LeastTests.test_update", - "db_functions.comparison.test_greatest.GreatestTests.test_update", - "db_functions.text.test_left.LeftTests.test_basic", - "db_functions.text.test_lower.LowerTests.test_basic", - "db_functions.text.test_replace.ReplaceTests.test_update", - "db_functions.text.test_substr.SubstrTests.test_basic", - "db_functions.text.test_upper.UpperTests.test_basic", - "expressions.tests.BasicExpressionsTests.test_arithmetic", - "expressions.tests.BasicExpressionsTests.test_filter_with_join", - "expressions.tests.BasicExpressionsTests.test_object_update", - "expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects", - "expressions.tests.BasicExpressionsTests.test_order_of_operations", - "expressions.tests.BasicExpressionsTests.test_parenthesis_priority", - "expressions.tests.BasicExpressionsTests.test_update", - "expressions.tests.BasicExpressionsTests.test_update_with_fk", - "expressions.tests.BasicExpressionsTests.test_update_with_none", - "expressions.tests.ExpressionsNumericTests.test_decimal_expression", - "expressions.tests.ExpressionsNumericTests.test_increment_value", - "expressions.tests.FTimeDeltaTests.test_delta_update", - "expressions.tests.FTimeDeltaTests.test_negative_timedelta_update", - "expressions.tests.ValueTests.test_update_TimeField_using_Value", - "expressions.tests.ValueTests.test_update_UUIDField_using_Value", - "expressions_case.tests.CaseDocumentationExamples.test_conditional_update_example", - "expressions_case.tests.CaseExpressionTests.test_update", - "expressions_case.tests.CaseExpressionTests.test_update_big_integer", - "expressions_case.tests.CaseExpressionTests.test_update_binary", - "expressions_case.tests.CaseExpressionTests.test_update_boolean", - "expressions_case.tests.CaseExpressionTests.test_update_date", - "expressions_case.tests.CaseExpressionTests.test_update_date_time", - "expressions_case.tests.CaseExpressionTests.test_update_decimal", - "expressions_case.tests.CaseExpressionTests.test_update_duration", - "expressions_case.tests.CaseExpressionTests.test_update_email", - "expressions_case.tests.CaseExpressionTests.test_update_file", - "expressions_case.tests.CaseExpressionTests.test_update_file_path", - "expressions_case.tests.CaseExpressionTests.test_update_fk", - "expressions_case.tests.CaseExpressionTests.test_update_float", - "expressions_case.tests.CaseExpressionTests.test_update_generic_ip_address", - "expressions_case.tests.CaseExpressionTests.test_update_image", - "expressions_case.tests.CaseExpressionTests.test_update_null_boolean", - "expressions_case.tests.CaseExpressionTests.test_update_positive_big_integer", - "expressions_case.tests.CaseExpressionTests.test_update_positive_integer", - "expressions_case.tests.CaseExpressionTests.test_update_positive_small_integer", - "expressions_case.tests.CaseExpressionTests.test_update_slug", - "expressions_case.tests.CaseExpressionTests.test_update_small_integer", - "expressions_case.tests.CaseExpressionTests.test_update_string", - "expressions_case.tests.CaseExpressionTests.test_update_text", - "expressions_case.tests.CaseExpressionTests.test_update_time", - "expressions_case.tests.CaseExpressionTests.test_update_url", - "expressions_case.tests.CaseExpressionTests.test_update_uuid", - "expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_condition", - "expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_value", - "expressions_case.tests.CaseExpressionTests.test_update_without_default", - "model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values", - "queries.test_bulk_update.BulkUpdateNoteTests", - "queries.test_bulk_update.BulkUpdateTests", - "timezones.tests.NewDatabaseTests.test_update_with_timedelta", - "update.tests.AdvancedTests.test_update_annotated_queryset", - "update.tests.AdvancedTests.test_update_negated_f", - "update.tests.AdvancedTests.test_update_negated_f_conditional_annotation", - "update.tests.AdvancedTests.test_update_transformed_field", - }, "AutoField not supported.": { "bulk_create.tests.BulkCreateTests.test_bulk_insert_nullable_fields", "lookup.tests.LookupTests.test_filter_by_reverse_related_field_transform", @@ -216,6 +153,9 @@ def django_test_expected_failures(self): "one_to_one.tests.OneToOneTests.test_multiple_o2o", "queries.test_bulk_update.BulkUpdateTests.test_database_routing_batch_atomicity", }, + "MongoDB does not enforce PositiveIntegerField constraint.": { + "model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values", + }, "Test assumes integer primary key.": { "db_functions.comparison.test_cast.CastTests.test_cast_to_integer_foreign_key", "model_fields.test_foreignkey.ForeignKeyTests.test_to_python", @@ -333,10 +273,10 @@ def django_test_expected_failures(self): "aggregation_regress.tests.AggregationTests.test_more_more_more5", "aggregation_regress.tests.AggregationTests.test_negated_aggregation", "db_functions.comparison.test_coalesce.CoalesceTests.test_empty_queryset", - "expressions_case.tests.CaseExpressionTests.test_annotate_with_in_clause", "expressions.tests.FTimeDeltaTests.test_date_subquery_subtraction", "expressions.tests.FTimeDeltaTests.test_datetime_subquery_subtraction", "expressions.tests.FTimeDeltaTests.test_time_subquery_subtraction", + "expressions_case.tests.CaseExpressionTests.test_annotate_with_in_clause", "expressions_case.tests.CaseExpressionTests.test_in_subquery", "lookup.tests.LookupTests.test_exact_query_rhs_with_selected_columns", "lookup.tests.LookupTests.test_exact_sliced_queryset_limit_one", @@ -386,8 +326,10 @@ def django_test_expected_failures(self): "one_to_one.tests.OneToOneTests.test_o2o_primary_key_delete", }, "Cannot use QuerySet.update() when querying across multiple collections on MongoDB.": { + "expressions.tests.BasicExpressionsTests.test_filter_with_join", "queries.tests.Queries4Tests.test_ticket7095", "queries.tests.Queries5Tests.test_ticket9848", + "update.tests.AdvancedTests.test_update_annotated_multi_table_queryset", "update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation", "update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation_desc", }, @@ -464,9 +406,6 @@ def django_test_expected_failures(self): "queries.tests.ValuesQuerysetTests.test_named_values_list_without_fields", "select_related.tests.SelectRelatedTests.test_select_related_with_extra", }, - "QuerySet.update() crash: Unrecognized expression '$count'": { - "update.tests.AdvancedTests.test_update_annotated_multi_table_queryset", - }, "Test inspects query for SQL": { "aggregation.tests.AggregateAnnotationPruningTests.test_non_aggregate_annotation_pruned", "aggregation.tests.AggregateAnnotationPruningTests.test_unreferenced_aggregate_annotation_pruned",