Skip to content

Commit cc29068

Browse files
committed
POC support update with expressions.
1 parent 143079b commit cc29068

File tree

2 files changed

+57
-90
lines changed

2 files changed

+57
-90
lines changed

django_mongodb/compiler.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from collections import defaultdict
44

55
from bson import SON
6-
from django.core.exceptions import EmptyResultSet, FullResultSet
7-
from django.db import DatabaseError, IntegrityError, NotSupportedError
8-
from django.db.models import Count, Expression
6+
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
7+
from django.db import IntegrityError, NotSupportedError
8+
from django.db.models import Count
99
from django.db.models.aggregates import Aggregate, Variance
1010
from django.db.models.expressions import Case, Col, Ref, Value, When
1111
from django.db.models.functions.comparison import Coalesce
@@ -581,7 +581,19 @@ def execute_sql(self, result_type):
581581
self.pre_sql_setup()
582582
values = []
583583
for field, _, value in self.query.values:
584-
if hasattr(value, "prepare_database_save"):
584+
if hasattr(value, "resolve_expression"):
585+
value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
586+
if value.contains_aggregate:
587+
raise FieldError(
588+
"Aggregate functions are not allowed in this query "
589+
f"({field.name}={value})."
590+
)
591+
if value.contains_over_clause:
592+
raise FieldError(
593+
"Window expressions are not allowed in this query "
594+
f"({field.name}={value})."
595+
)
596+
elif hasattr(value, "prepare_database_save"):
585597
if field.remote_field:
586598
value = value.prepare_database_save(field)
587599
else:
@@ -591,42 +603,44 @@ def execute_sql(self, result_type):
591603
f"{field.__class__.__name__}."
592604
)
593605
prepared = field.get_db_prep_save(value, connection=self.connection)
594-
values.append((field, prepared))
606+
if hasattr(value, "as_mql"):
607+
prepared = prepared.as_mql(self, self.connection)
608+
values.append((field.column, prepared))
609+
try:
610+
criteria = self.build_query().mongo_query
611+
except EmptyResultSet:
612+
return 0
595613
is_empty = not bool(values)
596-
rows = 0 if is_empty else self.update(values)
614+
if is_empty:
615+
rows = 0
616+
else:
617+
base_pipeline = [
618+
{"$match": criteria},
619+
{"$set": dict(values)},
620+
]
621+
count_pipeline = [*base_pipeline, {"$count": "count"}]
622+
pipeline = [
623+
*base_pipeline,
624+
{
625+
"$merge": {
626+
"into": self.collection_name,
627+
"whenMatched": "replace",
628+
"whenNotMatched": "discard",
629+
}
630+
},
631+
]
632+
with self.connection.connection.start_session() as session, session.start_transaction():
633+
result = next(self.collection.aggregate(count_pipeline), {"count": 0})
634+
self.collection.aggregate(pipeline)
635+
rows = result["count"]
636+
# rows = 0 if is_empty else self.update(values)
597637
for query in self.query.get_related_updates():
598638
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
599639
if is_empty and aux_rows:
600640
rows = aux_rows
601641
is_empty = False
602642
return rows
603643

604-
def update(self, values):
605-
spec = {}
606-
for field, value in values:
607-
if field.primary_key:
608-
raise DatabaseError("Cannot modify _id.")
609-
if isinstance(value, Expression):
610-
raise NotSupportedError("QuerySet.update() with expression not supported.")
611-
# .update(foo=123) --> {'$set': {'foo': 123}}
612-
spec.setdefault("$set", {})[field.column] = value
613-
return self.execute_update(spec)
614-
615-
@wrap_database_errors
616-
def execute_update(self, update_spec):
617-
try:
618-
criteria = self.build_query().mongo_query
619-
except EmptyResultSet:
620-
return 0
621-
return self.collection.update_many(criteria, update_spec).matched_count
622-
623-
def check_query(self):
624-
super().check_query()
625-
if len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1:
626-
raise NotSupportedError(
627-
"Cannot use QuerySet.update() when querying across multiple collections on MongoDB."
628-
)
629-
630644
def get_where(self):
631645
return self.query.where
632646

django_mongodb/features.py

Lines changed: 11 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -105,68 +105,21 @@ def django_test_expected_failures(self):
105105
"m2m_through_regress.test_multitable.MultiTableTests.test_m2m_prefetch_reverse_proxied",
106106
},
107107
"QuerySet.update() with expression not supported.": {
108-
"annotations.tests.AliasTests.test_update_with_alias",
109-
"annotations.tests.NonAggregateAnnotationTestCase.test_update_with_annotation",
110-
"db_functions.comparison.test_least.LeastTests.test_update",
111-
"db_functions.comparison.test_greatest.GreatestTests.test_update",
112-
"db_functions.text.test_left.LeftTests.test_basic",
113-
"db_functions.text.test_lower.LowerTests.test_basic",
114-
"db_functions.text.test_replace.ReplaceTests.test_update",
115-
"db_functions.text.test_substr.SubstrTests.test_basic",
116-
"db_functions.text.test_upper.UpperTests.test_basic",
117-
"expressions.tests.BasicExpressionsTests.test_arithmetic",
118-
"expressions.tests.BasicExpressionsTests.test_filter_with_join",
119-
"expressions.tests.BasicExpressionsTests.test_object_update",
120-
"expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects",
121-
"expressions.tests.BasicExpressionsTests.test_order_of_operations",
122-
"expressions.tests.BasicExpressionsTests.test_parenthesis_priority",
123-
"expressions.tests.BasicExpressionsTests.test_update",
124-
"expressions.tests.BasicExpressionsTests.test_update_with_fk",
125-
"expressions.tests.BasicExpressionsTests.test_update_with_none",
126-
"expressions.tests.ExpressionsNumericTests.test_decimal_expression",
127-
"expressions.tests.ExpressionsNumericTests.test_increment_value",
128-
"expressions.tests.FTimeDeltaTests.test_delta_update",
129-
"expressions.tests.FTimeDeltaTests.test_negative_timedelta_update",
130-
"expressions.tests.ValueTests.test_update_TimeField_using_Value",
131108
"expressions.tests.ValueTests.test_update_UUIDField_using_Value",
132-
"expressions_case.tests.CaseDocumentationExamples.test_conditional_update_example",
133-
"expressions_case.tests.CaseExpressionTests.test_update",
134-
"expressions_case.tests.CaseExpressionTests.test_update_big_integer",
135-
"expressions_case.tests.CaseExpressionTests.test_update_binary",
136-
"expressions_case.tests.CaseExpressionTests.test_update_boolean",
137-
"expressions_case.tests.CaseExpressionTests.test_update_date",
138-
"expressions_case.tests.CaseExpressionTests.test_update_date_time",
139-
"expressions_case.tests.CaseExpressionTests.test_update_decimal",
140-
"expressions_case.tests.CaseExpressionTests.test_update_duration",
141-
"expressions_case.tests.CaseExpressionTests.test_update_email",
142-
"expressions_case.tests.CaseExpressionTests.test_update_file",
143-
"expressions_case.tests.CaseExpressionTests.test_update_file_path",
144-
"expressions_case.tests.CaseExpressionTests.test_update_fk",
145-
"expressions_case.tests.CaseExpressionTests.test_update_float",
146-
"expressions_case.tests.CaseExpressionTests.test_update_generic_ip_address",
147-
"expressions_case.tests.CaseExpressionTests.test_update_image",
148-
"expressions_case.tests.CaseExpressionTests.test_update_null_boolean",
149-
"expressions_case.tests.CaseExpressionTests.test_update_positive_big_integer",
150-
"expressions_case.tests.CaseExpressionTests.test_update_positive_integer",
151-
"expressions_case.tests.CaseExpressionTests.test_update_positive_small_integer",
152-
"expressions_case.tests.CaseExpressionTests.test_update_slug",
153-
"expressions_case.tests.CaseExpressionTests.test_update_small_integer",
154-
"expressions_case.tests.CaseExpressionTests.test_update_string",
155-
"expressions_case.tests.CaseExpressionTests.test_update_text",
156-
"expressions_case.tests.CaseExpressionTests.test_update_time",
157-
"expressions_case.tests.CaseExpressionTests.test_update_url",
158109
"expressions_case.tests.CaseExpressionTests.test_update_uuid",
159-
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_condition",
160-
"expressions_case.tests.CaseExpressionTests.test_update_with_expression_as_value",
161-
"expressions_case.tests.CaseExpressionTests.test_update_without_default",
110+
"expressions.tests.FTimeDeltaTests.test_negative_timedelta_update",
111+
"expressions.tests.BasicExpressionsTests.test_filter_with_join",
112+
"expressions.tests.BasicExpressionsTests.test_object_update_unsaved_objects",
162113
"model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values",
163114
"queries.test_bulk_update.BulkUpdateNoteTests",
164-
"queries.test_bulk_update.BulkUpdateTests",
165-
"timezones.tests.NewDatabaseTests.test_update_with_timedelta",
166-
"update.tests.AdvancedTests.test_update_annotated_queryset",
167-
"update.tests.AdvancedTests.test_update_negated_f",
168-
"update.tests.AdvancedTests.test_update_negated_f_conditional_annotation",
169-
"update.tests.AdvancedTests.test_update_transformed_field",
115+
"queries.test_bulk_update.BulkUpdateTests.test_database_routing",
116+
"delete_regress.tests.SetQueryCountTests.test_set_querycount",
117+
"delete.tests.FastDeleteTests.test_fast_delete_qs",
118+
"delete.tests.FastDeleteTests.test_fast_delete_instance_set_pk_none",
119+
"delete.tests.FastDeleteTests.test_fast_delete_large_batch",
120+
"delete.tests.FastDeleteTests.test_fast_delete_fk",
121+
"delete.tests.DeletionTests.test_cannot_defer_constraint_checks",
122+
"basic.tests.SelectOnSaveTests.test_select_on_save",
170123
},
171124
"AutoField not supported.": {
172125
"bulk_create.tests.BulkCreateTests.test_bulk_insert_nullable_fields",

0 commit comments

Comments
 (0)