Skip to content

add expression support to QuerySet.update() #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 27 additions & 24 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from collections import defaultdict

from bson import SON
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db import DatabaseError, IntegrityError, NotSupportedError
from django.db.models import Count, Expression
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import IntegrityError, NotSupportedError
from django.db.models import Count
from django.db.models.aggregates import Aggregate, Variance
from django.db.models.expressions import Case, Col, Ref, Value, When
from django.db.models.functions.comparison import Coalesce
Expand Down Expand Up @@ -576,9 +576,21 @@ def execute_sql(self, result_type):
related queries are not available.
"""
self.pre_sql_setup()
values = []
values = {}
for field, _, value in self.query.values:
if hasattr(value, "prepare_database_save"):
if hasattr(value, "resolve_expression"):
value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
if value.contains_aggregate:
raise FieldError(
"Aggregate functions are not allowed in this query "
f"({field.name}={value})."
)
if value.contains_over_clause:
raise FieldError(
"Window expressions are not allowed in this query "
f"({field.name}={value})."
)
elif hasattr(value, "prepare_database_save"):
if field.remote_field:
value = value.prepare_database_save(field)
else:
Expand All @@ -588,34 +600,25 @@ def execute_sql(self, result_type):
f"{field.__class__.__name__}."
)
prepared = field.get_db_prep_save(value, connection=self.connection)
values.append((field, prepared))
if hasattr(value, "as_mql"):
prepared = prepared.as_mql(self, self.connection)
values[field.column] = prepared
try:
criteria = self.build_query().mongo_query
except EmptyResultSet:
return 0
is_empty = not bool(values)
rows = 0 if is_empty else self.update(values)
rows = 0 if is_empty else self.update(criteria, [{"$set": values}])
for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
if is_empty and aux_rows:
rows = aux_rows
is_empty = False
return rows

def update(self, values):
spec = {}
for field, value in values:
if field.primary_key:
raise DatabaseError("Cannot modify _id.")
if isinstance(value, Expression):
raise NotSupportedError("QuerySet.update() with expression not supported.")
# .update(foo=123) --> {'$set': {'foo': 123}}
spec.setdefault("$set", {})[field.column] = value
return self.execute_update(spec)

@wrap_database_errors
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We lost this decorator.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who needs the safe net?! my bad. 😬 deleted by mistake.

def execute_update(self, update_spec):
try:
criteria = self.build_query().mongo_query
except EmptyResultSet:
return 0
return self.collection.update_many(criteria, update_spec).matched_count
def update(self, criteria, pipeline):
return self.collection.update_many(criteria, pipeline).matched_count

def check_query(self):
super().check_query()
Expand Down
3 changes: 3 additions & 0 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from decimal import Decimal
from uuid import UUID

from bson import Decimal128
from django.core.exceptions import EmptyResultSet, FullResultSet
Expand Down Expand Up @@ -110,6 +111,8 @@ def value(self, compiler, connection): # noqa: ARG001
elif isinstance(value, datetime.timedelta):
# DurationField stores milliseconds rather than microseconds.
value /= datetime.timedelta(milliseconds=1)
elif isinstance(value, UUID):
value = value.hex
return {"$literal": value}


Expand Down
75 changes: 7 additions & 68 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def django_test_expected_failures(self):
"expressions.tests.BasicExpressionsTests.test_new_object_save",
"expressions.tests.BasicExpressionsTests.test_object_create_with_aggregate",
"expressions.tests.BasicExpressionsTests.test_object_create_with_f_expression_in_subquery",
"expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects",
# PI()
"db_functions.math.test_round.RoundTests.test_decimal_with_precision",
"db_functions.math.test_round.RoundTests.test_float_with_precision",
Expand All @@ -124,70 +125,6 @@ def django_test_expected_failures(self):
"many_to_many.tests.ManyToManyTests.test_set_after_prefetch",
"model_forms.tests.OtherModelFormTests.test_prefetch_related_queryset",
},
"QuerySet.update() with expression not supported.": {
"annotations.tests.AliasTests.test_update_with_alias",
"annotations.tests.NonAggregateAnnotationTestCase.test_update_with_annotation",
"db_functions.comparison.test_least.LeastTests.test_update",
"db_functions.comparison.test_greatest.GreatestTests.test_update",
"db_functions.text.test_left.LeftTests.test_basic",
"db_functions.text.test_lower.LowerTests.test_basic",
"db_functions.text.test_replace.ReplaceTests.test_update",
"db_functions.text.test_substr.SubstrTests.test_basic",
"db_functions.text.test_upper.UpperTests.test_basic",
"expressions.tests.BasicExpressionsTests.test_arithmetic",
"expressions.tests.BasicExpressionsTests.test_filter_with_join",
"expressions.tests.BasicExpressionsTests.test_object_update",
"expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects",
"expressions.tests.BasicExpressionsTests.test_order_of_operations",
"expressions.tests.BasicExpressionsTests.test_parenthesis_priority",
"expressions.tests.BasicExpressionsTests.test_update",
"expressions.tests.BasicExpressionsTests.test_update_with_fk",
"expressions.tests.BasicExpressionsTests.test_update_with_none",
"expressions.tests.ExpressionsNumericTests.test_decimal_expression",
"expressions.tests.ExpressionsNumericTests.test_increment_value",
"expressions.tests.FTimeDeltaTests.test_delta_update",
"expressions.tests.FTimeDeltaTests.test_negative_timedelta_update",
"expressions.tests.ValueTests.test_update_TimeField_using_Value",
"expressions.tests.ValueTests.test_update_UUIDField_using_Value",
"expressions_case.tests.CaseDocumentationExamples.test_conditional_update_example",
"expressions_case.tests.CaseExpressionTests.test_update",
"expressions_case.tests.CaseExpressionTests.test_update_big_integer",
"expressions_case.tests.CaseExpressionTests.test_update_binary",
"expressions_case.tests.CaseExpressionTests.test_update_boolean",
"expressions_case.tests.CaseExpressionTests.test_update_date",
"expressions_case.tests.CaseExpressionTests.test_update_date_time",
"expressions_case.tests.CaseExpressionTests.test_update_decimal",
"expressions_case.tests.CaseExpressionTests.test_update_duration",
"expressions_case.tests.CaseExpressionTests.test_update_email",
"expressions_case.tests.CaseExpressionTests.test_update_file",
"expressions_case.tests.CaseExpressionTests.test_update_file_path",
"expressions_case.tests.CaseExpressionTests.test_update_fk",
"expressions_case.tests.CaseExpressionTests.test_update_float",
"expressions_case.tests.CaseExpressionTests.test_update_generic_ip_address",
"expressions_case.tests.CaseExpressionTests.test_update_image",
"expressions_case.tests.CaseExpressionTests.test_update_null_boolean",
"expressions_case.tests.CaseExpressionTests.test_update_positive_big_integer",
"expressions_case.tests.CaseExpressionTests.test_update_positive_integer",
"expressions_case.tests.CaseExpressionTests.test_update_positive_small_integer",
"expressions_case.tests.CaseExpressionTests.test_update_slug",
"expressions_case.tests.CaseExpressionTests.test_update_small_integer",
"expressions_case.tests.CaseExpressionTests.test_update_string",
"expressions_case.tests.CaseExpressionTests.test_update_text",
"expressions_case.tests.CaseExpressionTests.test_update_time",
"expressions_case.tests.CaseExpressionTests.test_update_url",
"expressions_case.tests.CaseExpressionTests.test_update_uuid",
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_condition",
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_value",
"expressions_case.tests.CaseExpressionTests.test_update_without_default",
"model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values",
"queries.test_bulk_update.BulkUpdateNoteTests",
"queries.test_bulk_update.BulkUpdateTests",
"timezones.tests.NewDatabaseTests.test_update_with_timedelta",
"update.tests.AdvancedTests.test_update_annotated_queryset",
"update.tests.AdvancedTests.test_update_negated_f",
"update.tests.AdvancedTests.test_update_negated_f_conditional_annotation",
"update.tests.AdvancedTests.test_update_transformed_field",
},
"AutoField not supported.": {
"bulk_create.tests.BulkCreateTests.test_bulk_insert_nullable_fields",
"lookup.tests.LookupTests.test_filter_by_reverse_related_field_transform",
Expand Down Expand Up @@ -216,6 +153,9 @@ def django_test_expected_failures(self):
"one_to_one.tests.OneToOneTests.test_multiple_o2o",
"queries.test_bulk_update.BulkUpdateTests.test_database_routing_batch_atomicity",
},
"MongoDB does not enforce PositiveIntegerField constraint.": {
"model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values",
},
"Test assumes integer primary key.": {
"db_functions.comparison.test_cast.CastTests.test_cast_to_integer_foreign_key",
"model_fields.test_foreignkey.ForeignKeyTests.test_to_python",
Expand Down Expand Up @@ -333,10 +273,10 @@ def django_test_expected_failures(self):
"aggregation_regress.tests.AggregationTests.test_more_more_more5",
"aggregation_regress.tests.AggregationTests.test_negated_aggregation",
"db_functions.comparison.test_coalesce.CoalesceTests.test_empty_queryset",
"expressions_case.tests.CaseExpressionTests.test_annotate_with_in_clause",
"expressions.tests.FTimeDeltaTests.test_date_subquery_subtraction",
"expressions.tests.FTimeDeltaTests.test_datetime_subquery_subtraction",
"expressions.tests.FTimeDeltaTests.test_time_subquery_subtraction",
"expressions_case.tests.CaseExpressionTests.test_annotate_with_in_clause",
"expressions_case.tests.CaseExpressionTests.test_in_subquery",
"lookup.tests.LookupTests.test_exact_query_rhs_with_selected_columns",
"lookup.tests.LookupTests.test_exact_sliced_queryset_limit_one",
Expand Down Expand Up @@ -386,8 +326,10 @@ def django_test_expected_failures(self):
"one_to_one.tests.OneToOneTests.test_o2o_primary_key_delete",
},
"Cannot use QuerySet.update() when querying across multiple collections on MongoDB.": {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we should keep this error message. Raising "Using a QuerySet in annotate() is not supported on MongoDB" isn't helpful for the user since this isn't what they're doing but rather what Django is doing behind the scenes (kind of... really this error message needs to be changed since Query is used in more than just annotate()).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. yes, But the other test are going to be recategorized also?

"expressions.tests.BasicExpressionsTests.test_filter_with_join",
"queries.tests.Queries4Tests.test_ticket7095",
"queries.tests.Queries5Tests.test_ticket9848",
"update.tests.AdvancedTests.test_update_annotated_multi_table_queryset",
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation",
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation_desc",
},
Expand Down Expand Up @@ -464,9 +406,6 @@ def django_test_expected_failures(self):
"queries.tests.ValuesQuerysetTests.test_named_values_list_without_fields",
"select_related.tests.SelectRelatedTests.test_select_related_with_extra",
},
"QuerySet.update() crash: Unrecognized expression '$count'": {
"update.tests.AdvancedTests.test_update_annotated_multi_table_queryset",
},
"Test inspects query for SQL": {
"aggregation.tests.AggregateAnnotationPruningTests.test_non_aggregate_annotation_pruned",
"aggregation.tests.AggregateAnnotationPruningTests.test_unreferenced_aggregate_annotation_pruned",
Expand Down