Skip to content

POC add expression support to QuerySet.update() #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 46 additions & 32 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from collections import defaultdict

from bson import SON
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db import DatabaseError, IntegrityError, NotSupportedError
from django.db.models import Count, Expression
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import IntegrityError, NotSupportedError
from django.db.models import Count
from django.db.models.aggregates import Aggregate, Variance
from django.db.models.expressions import Case, Col, Ref, Value, When
from django.db.models.functions.comparison import Coalesce
Expand Down Expand Up @@ -581,7 +581,19 @@ def execute_sql(self, result_type):
self.pre_sql_setup()
values = []
for field, _, value in self.query.values:
if hasattr(value, "prepare_database_save"):
if hasattr(value, "resolve_expression"):
value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
if value.contains_aggregate:
raise FieldError(
"Aggregate functions are not allowed in this query "
f"({field.name}={value})."
)
if value.contains_over_clause:
raise FieldError(
"Window expressions are not allowed in this query "
f"({field.name}={value})."
)
elif hasattr(value, "prepare_database_save"):
if field.remote_field:
value = value.prepare_database_save(field)
else:
Expand All @@ -591,42 +603,44 @@ def execute_sql(self, result_type):
f"{field.__class__.__name__}."
)
prepared = field.get_db_prep_save(value, connection=self.connection)
values.append((field, prepared))
if hasattr(value, "as_mql"):
prepared = prepared.as_mql(self, self.connection)
values.append((field.column, prepared))
try:
criteria = self.build_query().mongo_query
except EmptyResultSet:
return 0
is_empty = not bool(values)
rows = 0 if is_empty else self.update(values)
if is_empty:
rows = 0
else:
base_pipeline = [
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is the only thing that need focus. Here I just create the $merge pipeline and the affected rows pipeline with a single transaction. I really don't know if this is the way to go.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the right way to go about it. having a transaction ensures that state isn't changed between the lookup and the update.

Copy link
Collaborator Author

@WaVEV WaVEV Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can use update many directly as Shane mentioned. I've re-read the docs and what the docs support are this three stages.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeap, It works 🚀. So we are able to do the update without making two queries (one for count and the other for the update)

{"$match": criteria},
{"$set": dict(values)},
]
count_pipeline = [*base_pipeline, {"$count": "count"}]
pipeline = [
*base_pipeline,
{
"$merge": {
"into": self.collection_name,
"whenMatched": "replace",
"whenNotMatched": "discard",
}
},
]
with self.connection.connection.start_session() as session, session.start_transaction():
result = next(self.collection.aggregate(count_pipeline), {"count": 0})
self.collection.aggregate(pipeline)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what has been the output when you merge:
[*base_pipeline, {$merge..}, {$count...}] Does this not work whatsoever?

rows = result["count"]
# rows = 0 if is_empty else self.update(values)
for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
if is_empty and aux_rows:
rows = aux_rows
is_empty = False
return rows

def update(self, values):
spec = {}
for field, value in values:
if field.primary_key:
raise DatabaseError("Cannot modify _id.")
if isinstance(value, Expression):
raise NotSupportedError("QuerySet.update() with expression not supported.")
# .update(foo=123) --> {'$set': {'foo': 123}}
spec.setdefault("$set", {})[field.column] = value
return self.execute_update(spec)

@wrap_database_errors
def execute_update(self, update_spec):
try:
criteria = self.build_query().mongo_query
except EmptyResultSet:
return 0
return self.collection.update_many(criteria, update_spec).matched_count

def check_query(self):
super().check_query()
if len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1:
raise NotSupportedError(
"Cannot use QuerySet.update() when querying across multiple collections on MongoDB."
)

def get_where(self):
return self.query.where

Expand Down
69 changes: 11 additions & 58 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,68 +105,21 @@ def django_test_expected_failures(self):
"m2m_through_regress.test_multitable.MultiTableTests.test_m2m_prefetch_reverse_proxied",
},
"QuerySet.update() with expression not supported.": {
"annotations.tests.AliasTests.test_update_with_alias",
"annotations.tests.NonAggregateAnnotationTestCase.test_update_with_annotation",
"db_functions.comparison.test_least.LeastTests.test_update",
"db_functions.comparison.test_greatest.GreatestTests.test_update",
"db_functions.text.test_left.LeftTests.test_basic",
"db_functions.text.test_lower.LowerTests.test_basic",
"db_functions.text.test_replace.ReplaceTests.test_update",
"db_functions.text.test_substr.SubstrTests.test_basic",
"db_functions.text.test_upper.UpperTests.test_basic",
"expressions.tests.BasicExpressionsTests.test_arithmetic",
"expressions.tests.BasicExpressionsTests.test_filter_with_join",
"expressions.tests.BasicExpressionsTests.test_object_update",
"expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects",
"expressions.tests.BasicExpressionsTests.test_order_of_operations",
"expressions.tests.BasicExpressionsTests.test_parenthesis_priority",
"expressions.tests.BasicExpressionsTests.test_update",
"expressions.tests.BasicExpressionsTests.test_update_with_fk",
"expressions.tests.BasicExpressionsTests.test_update_with_none",
"expressions.tests.ExpressionsNumericTests.test_decimal_expression",
"expressions.tests.ExpressionsNumericTests.test_increment_value",
"expressions.tests.FTimeDeltaTests.test_delta_update",
"expressions.tests.FTimeDeltaTests.test_negative_timedelta_update",
"expressions.tests.ValueTests.test_update_TimeField_using_Value",
"expressions.tests.ValueTests.test_update_UUIDField_using_Value",
"expressions_case.tests.CaseDocumentationExamples.test_conditional_update_example",
"expressions_case.tests.CaseExpressionTests.test_update",
"expressions_case.tests.CaseExpressionTests.test_update_big_integer",
"expressions_case.tests.CaseExpressionTests.test_update_binary",
"expressions_case.tests.CaseExpressionTests.test_update_boolean",
"expressions_case.tests.CaseExpressionTests.test_update_date",
"expressions_case.tests.CaseExpressionTests.test_update_date_time",
"expressions_case.tests.CaseExpressionTests.test_update_decimal",
"expressions_case.tests.CaseExpressionTests.test_update_duration",
"expressions_case.tests.CaseExpressionTests.test_update_email",
"expressions_case.tests.CaseExpressionTests.test_update_file",
"expressions_case.tests.CaseExpressionTests.test_update_file_path",
"expressions_case.tests.CaseExpressionTests.test_update_fk",
"expressions_case.tests.CaseExpressionTests.test_update_float",
"expressions_case.tests.CaseExpressionTests.test_update_generic_ip_address",
"expressions_case.tests.CaseExpressionTests.test_update_image",
"expressions_case.tests.CaseExpressionTests.test_update_null_boolean",
"expressions_case.tests.CaseExpressionTests.test_update_positive_big_integer",
"expressions_case.tests.CaseExpressionTests.test_update_positive_integer",
"expressions_case.tests.CaseExpressionTests.test_update_positive_small_integer",
"expressions_case.tests.CaseExpressionTests.test_update_slug",
"expressions_case.tests.CaseExpressionTests.test_update_small_integer",
"expressions_case.tests.CaseExpressionTests.test_update_string",
"expressions_case.tests.CaseExpressionTests.test_update_text",
"expressions_case.tests.CaseExpressionTests.test_update_time",
"expressions_case.tests.CaseExpressionTests.test_update_url",
"expressions_case.tests.CaseExpressionTests.test_update_uuid",
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_condition",
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_value",
"expressions_case.tests.CaseExpressionTests.test_update_without_default",
"expressions.tests.FTimeDeltaTests.test_negative_timedelta_update",
"expressions.tests.BasicExpressionsTests.test_filter_with_join",
"expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects",
"model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values",
"queries.test_bulk_update.BulkUpdateNoteTests",
"queries.test_bulk_update.BulkUpdateTests",
"timezones.tests.NewDatabaseTests.test_update_with_timedelta",
"update.tests.AdvancedTests.test_update_annotated_queryset",
"update.tests.AdvancedTests.test_update_negated_f",
"update.tests.AdvancedTests.test_update_negated_f_conditional_annotation",
"update.tests.AdvancedTests.test_update_transformed_field",
"queries.test_bulk_update.BulkUpdateTests.test_database_routing",
"delete_regress.tests.SetQueryCountTests.test_set_querycount",
"delete.tests.FastDeleteTests.test_fast_delete_qs",
"delete.tests.FastDeleteTests.test_fast_delete_instance_set_pk_none",
"delete.tests.FastDeleteTests.test_fast_delete_large_batch",
"delete.tests.FastDeleteTests.test_fast_delete_fk",
"delete.tests.DeletionTests.test_cannot_defer_constraint_checks",
"basic.tests.SelectOnSaveTests.test_select_on_save",
},
"AutoField not supported.": {
"bulk_create.tests.BulkCreateTests.test_bulk_insert_nullable_fields",
Expand Down