Skip to content

POC add expression support to QuerySet.update() #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 28 additions & 32 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from collections import defaultdict

from bson import SON
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db import DatabaseError, IntegrityError, NotSupportedError
from django.db.models import Count, Expression
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import IntegrityError, NotSupportedError
from django.db.models import Count
from django.db.models.aggregates import Aggregate, Variance
from django.db.models.expressions import Case, Col, Ref, Value, When
from django.db.models.functions.comparison import Coalesce
Expand Down Expand Up @@ -581,7 +581,19 @@ def execute_sql(self, result_type):
self.pre_sql_setup()
values = []
for field, _, value in self.query.values:
if hasattr(value, "prepare_database_save"):
if hasattr(value, "resolve_expression"):
value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
if value.contains_aggregate:
raise FieldError(
"Aggregate functions are not allowed in this query "
f"({field.name}={value})."
)
if value.contains_over_clause:
raise FieldError(
"Window expressions are not allowed in this query "
f"({field.name}={value})."
)
elif hasattr(value, "prepare_database_save"):
if field.remote_field:
value = value.prepare_database_save(field)
else:
Expand All @@ -591,42 +603,26 @@ def execute_sql(self, result_type):
f"{field.__class__.__name__}."
)
prepared = field.get_db_prep_save(value, connection=self.connection)
values.append((field, prepared))
if hasattr(value, "as_mql"):
prepared = prepared.as_mql(self, self.connection)
values.append((field.column, prepared))
try:
criteria = self.build_query().mongo_query
except EmptyResultSet:
return 0
is_empty = not bool(values)
rows = 0 if is_empty else self.update(values)
if is_empty:
rows = 0
else:
rows = self.collection.update_many(criteria, [{"$set": dict(values)}]).matched_count

for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
if is_empty and aux_rows:
rows = aux_rows
is_empty = False
return rows

def update(self, values):
spec = {}
for field, value in values:
if field.primary_key:
raise DatabaseError("Cannot modify _id.")
if isinstance(value, Expression):
raise NotSupportedError("QuerySet.update() with expression not supported.")
# .update(foo=123) --> {'$set': {'foo': 123}}
spec.setdefault("$set", {})[field.column] = value
return self.execute_update(spec)

@wrap_database_errors
def execute_update(self, update_spec):
try:
criteria = self.build_query().mongo_query
except EmptyResultSet:
return 0
return self.collection.update_many(criteria, update_spec).matched_count

def check_query(self):
super().check_query()
if len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1:
raise NotSupportedError(
"Cannot use QuerySet.update() when querying across multiple collections on MongoDB."
)

def get_where(self):
return self.query.where

Expand Down
69 changes: 11 additions & 58 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,68 +105,21 @@ def django_test_expected_failures(self):
"m2m_through_regress.test_multitable.MultiTableTests.test_m2m_prefetch_reverse_proxied",
},
"QuerySet.update() with expression not supported.": {
"annotations.tests.AliasTests.test_update_with_alias",
"annotations.tests.NonAggregateAnnotationTestCase.test_update_with_annotation",
"db_functions.comparison.test_least.LeastTests.test_update",
"db_functions.comparison.test_greatest.GreatestTests.test_update",
"db_functions.text.test_left.LeftTests.test_basic",
"db_functions.text.test_lower.LowerTests.test_basic",
"db_functions.text.test_replace.ReplaceTests.test_update",
"db_functions.text.test_substr.SubstrTests.test_basic",
"db_functions.text.test_upper.UpperTests.test_basic",
"expressions.tests.BasicExpressionsTests.test_arithmetic",
"expressions.tests.BasicExpressionsTests.test_filter_with_join",
"expressions.tests.BasicExpressionsTests.test_object_update",
"expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects",
"expressions.tests.BasicExpressionsTests.test_order_of_operations",
"expressions.tests.BasicExpressionsTests.test_parenthesis_priority",
"expressions.tests.BasicExpressionsTests.test_update",
"expressions.tests.BasicExpressionsTests.test_update_with_fk",
"expressions.tests.BasicExpressionsTests.test_update_with_none",
"expressions.tests.ExpressionsNumericTests.test_decimal_expression",
"expressions.tests.ExpressionsNumericTests.test_increment_value",
"expressions.tests.FTimeDeltaTests.test_delta_update",
"expressions.tests.FTimeDeltaTests.test_negative_timedelta_update",
"expressions.tests.ValueTests.test_update_TimeField_using_Value",
"expressions.tests.ValueTests.test_update_UUIDField_using_Value",
"expressions_case.tests.CaseDocumentationExamples.test_conditional_update_example",
"expressions_case.tests.CaseExpressionTests.test_update",
"expressions_case.tests.CaseExpressionTests.test_update_big_integer",
"expressions_case.tests.CaseExpressionTests.test_update_binary",
"expressions_case.tests.CaseExpressionTests.test_update_boolean",
"expressions_case.tests.CaseExpressionTests.test_update_date",
"expressions_case.tests.CaseExpressionTests.test_update_date_time",
"expressions_case.tests.CaseExpressionTests.test_update_decimal",
"expressions_case.tests.CaseExpressionTests.test_update_duration",
"expressions_case.tests.CaseExpressionTests.test_update_email",
"expressions_case.tests.CaseExpressionTests.test_update_file",
"expressions_case.tests.CaseExpressionTests.test_update_file_path",
"expressions_case.tests.CaseExpressionTests.test_update_fk",
"expressions_case.tests.CaseExpressionTests.test_update_float",
"expressions_case.tests.CaseExpressionTests.test_update_generic_ip_address",
"expressions_case.tests.CaseExpressionTests.test_update_image",
"expressions_case.tests.CaseExpressionTests.test_update_null_boolean",
"expressions_case.tests.CaseExpressionTests.test_update_positive_big_integer",
"expressions_case.tests.CaseExpressionTests.test_update_positive_integer",
"expressions_case.tests.CaseExpressionTests.test_update_positive_small_integer",
"expressions_case.tests.CaseExpressionTests.test_update_slug",
"expressions_case.tests.CaseExpressionTests.test_update_small_integer",
"expressions_case.tests.CaseExpressionTests.test_update_string",
"expressions_case.tests.CaseExpressionTests.test_update_text",
"expressions_case.tests.CaseExpressionTests.test_update_time",
"expressions_case.tests.CaseExpressionTests.test_update_url",
"expressions_case.tests.CaseExpressionTests.test_update_uuid",
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_condition",
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_value",
"expressions_case.tests.CaseExpressionTests.test_update_without_default",
"expressions.tests.FTimeDeltaTests.test_negative_timedelta_update",
"expressions.tests.BasicExpressionsTests.test_filter_with_join",
"expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects",
"model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values",
"queries.test_bulk_update.BulkUpdateNoteTests",
"queries.test_bulk_update.BulkUpdateTests",
"timezones.tests.NewDatabaseTests.test_update_with_timedelta",
"update.tests.AdvancedTests.test_update_annotated_queryset",
"update.tests.AdvancedTests.test_update_negated_f",
"update.tests.AdvancedTests.test_update_negated_f_conditional_annotation",
"update.tests.AdvancedTests.test_update_transformed_field",
"queries.test_bulk_update.BulkUpdateTests.test_database_routing",
"delete_regress.tests.SetQueryCountTests.test_set_querycount",
"delete.tests.FastDeleteTests.test_fast_delete_qs",
"delete.tests.FastDeleteTests.test_fast_delete_instance_set_pk_none",
"delete.tests.FastDeleteTests.test_fast_delete_large_batch",
"delete.tests.FastDeleteTests.test_fast_delete_fk",
"delete.tests.DeletionTests.test_cannot_defer_constraint_checks",
"basic.tests.SelectOnSaveTests.test_select_on_save",
},
"AutoField not supported.": {
"bulk_create.tests.BulkCreateTests.test_bulk_insert_nullable_fields",
Expand Down