Skip to content

Commit e797d85

Browse files
committed
fix QuerySet.update() with multi-table inheritance
1 parent c05c3a2 commit e797d85

File tree

3 files changed

+48
-16
lines changed

3 files changed

+48
-16
lines changed

django_mongodb/compiler.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from itertools import chain
2+
13
from django.core.exceptions import EmptyResultSet, FullResultSet
24
from django.db import DatabaseError, IntegrityError, NotSupportedError
35
from django.db.models import Count, Expression
46
from django.db.models.aggregates import Aggregate
57
from django.db.models.expressions import OrderBy
68
from django.db.models.sql import compiler
7-
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR
9+
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR, SINGLE
810
from django.utils.functional import cached_property
911

1012
from .base import Cursor
@@ -33,11 +35,18 @@ def execute_sql(
3335
except EmptyResultSet:
3436
return iter([]) if result_type == MULTI else None
3537

36-
return (
37-
(self._make_result(row, columns) for row in query.fetch())
38-
if result_type == MULTI
39-
else self._make_result(next(query.fetch()), columns)
40-
)
38+
cursor = query.get_cursor()
39+
if result_type == SINGLE:
40+
return self._make_result(next(cursor), columns)
41+
# result_type is MULTI
42+
cursor.batch_size(chunk_size)
43+
result = self.cursor_iter(cursor, chunk_size, columns)
44+
if not chunked_fetch:
45+
# If using non-chunked reads, return the same data structure as
46+
# normally, but ensure it is all read into memory before going
47+
# any further.
48+
return list(result)
49+
return result
4150

4251
def results_iter(
4352
self,
@@ -49,14 +58,23 @@ def results_iter(
4958
"""
5059
Return an iterator over the results from executing query given
5160
to this compiler. Called by QuerySet methods.
61+
62+
This method is copied from the superclass with one modification: the
63+
`if tuple_expected` block is deindented so that the result of
64+
_make_result() (a list) is cast to tuple as needed. For SQL database
65+
drivers, results come from cursor.fetchmany() and are tuples, so the
66+
cast is only needed there when apply_converters() casts the tuple to a
67+
list. This customized method could be removed if _make_result() cast
68+
its return value to a tuple, but that would be more expensive since
69+
it's not always needed.
5270
"""
5371
if results is None:
5472
# QuerySet.values() or values_list()
5573
results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size)
5674

5775
fields = [s[0] for s in self.select[0 : self.col_count]]
5876
converters = self.get_converters(fields)
59-
rows = results
77+
rows = chain.from_iterable(results)
6078
if converters:
6179
rows = self.apply_converters(rows, converters)
6280
if tuple_expected:
@@ -86,6 +104,15 @@ def _make_result(self, entity, columns):
86104
result.append(obj.get(name))
87105
return result
88106

107+
def cursor_iter(self, cursor, chunk_size, columns):
108+
"""Generator to yield chunks from cursor."""
109+
chunk = []
110+
for i, row in enumerate(cursor):
111+
if i % chunk_size == 0 and i > 0:
112+
yield chunk
113+
chunk.append(self._make_result(row, columns))
114+
yield chunk
115+
89116
def check_query(self):
90117
"""Check if the current query is supported by the database."""
91118
if self.query.is_empty():
@@ -293,8 +320,14 @@ def check_query(self):
293320
)
294321

295322

296-
class SQLUpdateCompiler(SQLCompiler):
323+
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
297324
def execute_sql(self, result_type):
325+
"""
326+
Execute the specified update. Return the number of rows affected by
327+
the primary update query. The "primary update query" is the first
328+
non-empty query that is executed. Row counts for any subsequent,
329+
related queries are not available.
330+
"""
298331
self.pre_sql_setup()
299332
values = []
300333
for field, _, value in self.query.values:
@@ -309,7 +342,13 @@ def execute_sql(self, result_type):
309342
)
310343
prepared = field.get_db_prep_save(value, connection=self.connection)
311344
values.append((field, prepared))
312-
return self.update(values)
345+
rows = 0 if (is_empty := not bool(values)) else self.update(values)
346+
for query in self.query.get_related_updates():
347+
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
348+
if is_empty and aux_rows:
349+
rows = aux_rows
350+
is_empty = False
351+
return rows
313352

314353
def update(self, values):
315354
spec = {}

django_mongodb/features.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
5151
"update.tests.AdvancedTests.test_update_ordered_by_inline_m2m_annotation",
5252
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation",
5353
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation_desc",
54-
# pymongo: ValueError: update cannot be empty
55-
"update.tests.SimpleTest.test_empty_update_with_inheritance",
56-
"update.tests.SimpleTest.test_nonempty_update_with_inheritance",
5754
# Pattern lookups that use regexMatch don't work on JSONField:
5855
# Unsupported conversion from array to string in $convert
5956
"model_fields.test_jsonfield.TestQuerying.test_icontains",

django_mongodb/query.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,6 @@ def __init__(self, compiler, columns):
5353
def __repr__(self):
5454
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
5555

56-
def fetch(self):
57-
"""Return an iterator over the query results."""
58-
yield from self.get_cursor()
59-
6056
@wrap_database_errors
6157
def count(self, limit=None, skip=None):
6258
"""

0 commit comments

Comments
 (0)