Skip to content

Commit 6c9620c

Browse files
committed
Support update with expressions.
1 parent 9dd5f08 commit 6c9620c

File tree

1 file changed

+29
-33
lines changed

1 file changed

+29
-33
lines changed

django_mongodb/compiler.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from collections import defaultdict
44

55
from bson import SON
6-
from django.core.exceptions import EmptyResultSet, FullResultSet
7-
from django.db import DatabaseError, IntegrityError, NotSupportedError
8-
from django.db.models import Count, Expression
6+
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
7+
from django.db import IntegrityError, NotSupportedError
8+
from django.db.models import Count
99
from django.db.models.aggregates import Aggregate, Variance
1010
from django.db.models.expressions import Case, Col, Ref, Value, When
1111
from django.db.models.functions.comparison import Coalesce
@@ -576,9 +576,21 @@ def execute_sql(self, result_type):
576576
related queries are not available.
577577
"""
578578
self.pre_sql_setup()
579-
values = []
579+
values = {}
580580
for field, _, value in self.query.values:
581-
if hasattr(value, "prepare_database_save"):
581+
if hasattr(value, "resolve_expression"):
582+
value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
583+
if value.contains_aggregate:
584+
raise FieldError(
585+
"Aggregate functions are not allowed in this query "
586+
f"({field.name}={value})."
587+
)
588+
if value.contains_over_clause:
589+
raise FieldError(
590+
"Window expressions are not allowed in this query "
591+
f"({field.name}={value})."
592+
)
593+
elif hasattr(value, "prepare_database_save"):
582594
if field.remote_field:
583595
value = value.prepare_database_save(field)
584596
else:
@@ -588,42 +600,26 @@ def execute_sql(self, result_type):
588600
f"{field.__class__.__name__}."
589601
)
590602
prepared = field.get_db_prep_save(value, connection=self.connection)
591-
values.append((field, prepared))
603+
if hasattr(value, "as_mql"):
604+
prepared = prepared.as_mql(self, self.connection)
605+
values[field.column] = prepared
606+
try:
607+
criteria = self.build_query().mongo_query
608+
except EmptyResultSet:
609+
return 0
592610
is_empty = not bool(values)
593-
rows = 0 if is_empty else self.update(values)
611+
rows = (
612+
0
613+
if is_empty
614+
else self.collection.update_many(criteria, [{"$set": values}]).matched_count
615+
)
594616
for query in self.query.get_related_updates():
595617
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
596618
if is_empty and aux_rows:
597619
rows = aux_rows
598620
is_empty = False
599621
return rows
600622

601-
def update(self, values):
602-
spec = {}
603-
for field, value in values:
604-
if field.primary_key:
605-
raise DatabaseError("Cannot modify _id.")
606-
if isinstance(value, Expression):
607-
raise NotSupportedError("QuerySet.update() with expression not supported.")
608-
# .update(foo=123) --> {'$set': {'foo': 123}}
609-
spec.setdefault("$set", {})[field.column] = value
610-
return self.execute_update(spec)
611-
612-
@wrap_database_errors
613-
def execute_update(self, update_spec):
614-
try:
615-
criteria = self.build_query().mongo_query
616-
except EmptyResultSet:
617-
return 0
618-
return self.collection.update_many(criteria, update_spec).matched_count
619-
620-
def check_query(self):
621-
super().check_query()
622-
if len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1:
623-
raise NotSupportedError(
624-
"Cannot use QuerySet.update() when querying across multiple collections on MongoDB."
625-
)
626-
627623
def get_where(self):
628624
return self.query.where
629625

0 commit comments

Comments
 (0)