Skip to content

Commit ab0c5e1

Browse files
committed
Support update with expressions.
1 parent 9b2d02f commit ab0c5e1

File tree

1 file changed

+29
-33
lines changed

1 file changed

+29
-33
lines changed

django_mongodb/compiler.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from collections import defaultdict
44

55
from bson import SON
6-
from django.core.exceptions import EmptyResultSet, FullResultSet
7-
from django.db import DatabaseError, IntegrityError, NotSupportedError
8-
from django.db.models import Count, Expression
6+
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
7+
from django.db import IntegrityError, NotSupportedError
8+
from django.db.models import Count
99
from django.db.models.aggregates import Aggregate, Variance
1010
from django.db.models.expressions import Case, Col, Ref, Value, When
1111
from django.db.models.functions.comparison import Coalesce
@@ -579,9 +579,21 @@ def execute_sql(self, result_type):
579579
related queries are not available.
580580
"""
581581
self.pre_sql_setup()
582-
values = []
582+
values = {}
583583
for field, _, value in self.query.values:
584-
if hasattr(value, "prepare_database_save"):
584+
if hasattr(value, "resolve_expression"):
585+
value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
586+
if value.contains_aggregate:
587+
raise FieldError(
588+
"Aggregate functions are not allowed in this query "
589+
f"({field.name}={value})."
590+
)
591+
if value.contains_over_clause:
592+
raise FieldError(
593+
"Window expressions are not allowed in this query "
594+
f"({field.name}={value})."
595+
)
596+
elif hasattr(value, "prepare_database_save"):
585597
if field.remote_field:
586598
value = value.prepare_database_save(field)
587599
else:
@@ -591,42 +603,26 @@ def execute_sql(self, result_type):
591603
f"{field.__class__.__name__}."
592604
)
593605
prepared = field.get_db_prep_save(value, connection=self.connection)
594-
values.append((field, prepared))
606+
if hasattr(value, "as_mql"):
607+
prepared = prepared.as_mql(self, self.connection)
608+
values[field.column] = prepared
609+
try:
610+
criteria = self.build_query().mongo_query
611+
except EmptyResultSet:
612+
return 0
595613
is_empty = not bool(values)
596-
rows = 0 if is_empty else self.update(values)
614+
rows = (
615+
0
616+
if is_empty
617+
else self.collection.update_many(criteria, [{"$set": values}]).matched_count
618+
)
597619
for query in self.query.get_related_updates():
598620
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
599621
if is_empty and aux_rows:
600622
rows = aux_rows
601623
is_empty = False
602624
return rows
603625

604-
def update(self, values):
605-
spec = {}
606-
for field, value in values:
607-
if field.primary_key:
608-
raise DatabaseError("Cannot modify _id.")
609-
if isinstance(value, Expression):
610-
raise NotSupportedError("QuerySet.update() with expression not supported.")
611-
# .update(foo=123) --> {'$set': {'foo': 123}}
612-
spec.setdefault("$set", {})[field.column] = value
613-
return self.execute_update(spec)
614-
615-
@wrap_database_errors
616-
def execute_update(self, update_spec):
617-
try:
618-
criteria = self.build_query().mongo_query
619-
except EmptyResultSet:
620-
return 0
621-
return self.collection.update_many(criteria, update_spec).matched_count
622-
623-
def check_query(self):
624-
super().check_query()
625-
if len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1:
626-
raise NotSupportedError(
627-
"Cannot use QuerySet.update() when querying across multiple collections on MongoDB."
628-
)
629-
630626
def get_where(self):
631627
return self.query.where
632628

0 commit comments

Comments
 (0)