Skip to content

Commit 602716f

Browse files
committed
fix QuerySet.update() with multi-table inheritance
1 parent c05c3a2 commit 602716f

File tree

3 files changed

+53
-15
lines changed

3 files changed

+53
-15
lines changed

django_mongodb/compiler.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from itertools import chain
2+
13
from django.core.exceptions import EmptyResultSet, FullResultSet
24
from django.db import DatabaseError, IntegrityError, NotSupportedError
35
from django.db.models import Count, Expression
46
from django.db.models.aggregates import Aggregate
57
from django.db.models.expressions import OrderBy
68
from django.db.models.sql import compiler
7-
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR
9+
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR, SINGLE
810
from django.utils.functional import cached_property
911

1012
from .base import Cursor
@@ -33,11 +35,23 @@ def execute_sql(
3335
except EmptyResultSet:
3436
return iter([]) if result_type == MULTI else None
3537

36-
return (
37-
(self._make_result(row, columns) for row in query.fetch())
38-
if result_type == MULTI
39-
else self._make_result(next(query.fetch()), columns)
40-
)
38+
cursor = query.get_cursor()
39+
40+
if result_type == SINGLE:
41+
return self._make_result(next(cursor))
42+
43+
if result_type == MULTI:
44+
result = self.cursor_iter(cursor, chunk_size, columns)
45+
46+
if not chunked_fetch or not self.connection.features.can_use_chunked_reads:
47+
# If using non-chunked reads, return the same data structure as
48+
# normally, but ensure it is all read into memory before going
49+
# any further. Use chunked_fetch if requested,
50+
# unless the database doesn't support it.
51+
return list(result)
52+
return result
53+
54+
raise ValueError("shouldn't get here")
4155

4256
def results_iter(
4357
self,
@@ -49,14 +63,16 @@ def results_iter(
4963
"""
5064
Return an iterator over the results from executing query given
5165
to this compiler. Called by QuerySet methods.
66+
67+
XXX: Can this method be removed?
5268
"""
5369
if results is None:
5470
# QuerySet.values() or values_list()
5571
results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size)
5672

5773
fields = [s[0] for s in self.select[0 : self.col_count]]
5874
converters = self.get_converters(fields)
59-
rows = results
75+
rows = chain.from_iterable(results)
6076
if converters:
6177
rows = self.apply_converters(rows, converters)
6278
if tuple_expected:
@@ -86,6 +102,15 @@ def _make_result(self, entity, columns):
86102
result.append(obj.get(name))
87103
return result
88104

105+
def cursor_iter(self, cursor, chunk_size, columns):
106+
"""Generator to yield chunks from cursor."""
107+
chunk = []
108+
for i, row in enumerate(cursor):
109+
if i % chunk_size == 0 and i > 0:
110+
yield chunk
111+
chunk.append(self._make_result(row, columns))
112+
yield chunk
113+
89114
def check_query(self):
90115
"""Check if the current query is supported by the database."""
91116
if self.query.is_empty():
@@ -293,8 +318,14 @@ def check_query(self):
293318
)
294319

295320

296-
class SQLUpdateCompiler(SQLCompiler):
321+
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
297322
def execute_sql(self, result_type):
323+
"""
324+
Execute the specified update. Return the number of rows affected by
325+
the primary update query. The "primary update query" is the first
326+
non-empty query that is executed. Row counts for any subsequent,
327+
related queries are not available.
328+
"""
298329
self.pre_sql_setup()
299330
values = []
300331
for field, _, value in self.query.values:
@@ -309,7 +340,17 @@ def execute_sql(self, result_type):
309340
)
310341
prepared = field.get_db_prep_save(value, connection=self.connection)
311342
values.append((field, prepared))
312-
return self.update(values)
343+
# if is_empty := not bool(values):
344+
# rows = 0
345+
# else:
346+
# rows = self.update(values)
347+
rows = 0 if (is_empty := not bool(values)) else self.update(values)
348+
for query in self.query.get_related_updates():
349+
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
350+
if is_empty and aux_rows:
351+
rows = aux_rows
352+
is_empty = False
353+
return rows
313354

314355
def update(self, values):
315356
spec = {}

django_mongodb/features.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
5151
"update.tests.AdvancedTests.test_update_ordered_by_inline_m2m_annotation",
5252
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation",
5353
"update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation_desc",
54-
# pymongo: ValueError: update cannot be empty
55-
"update.tests.SimpleTest.test_empty_update_with_inheritance",
56-
"update.tests.SimpleTest.test_nonempty_update_with_inheritance",
5754
# Pattern lookups that use regexMatch don't work on JSONField:
5855
# Unsupported conversion from array to string in $convert
5956
"model_fields.test_jsonfield.TestQuerying.test_icontains",

django_mongodb/query.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def __init__(self, compiler, columns):
5353
def __repr__(self):
5454
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
5555

56-
def fetch(self):
57-
"""Return an iterator over the query results."""
58-
yield from self.get_cursor()
56+
# def fetch(self):
57+
# """Return an iterator over the query results."""
58+
# yield from self.get_cursor()
5959

6060
@wrap_database_errors
6161
def count(self, limit=None, skip=None):

0 commit comments

Comments
 (0)