Skip to content

Commit b0a8296

Browse files
committed
made QuerySet iteration respect chunk_size
1 parent c05c3a2 commit b0a8296

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

django_mongodb/compiler.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from itertools import chain
2+
13
from django.core.exceptions import EmptyResultSet, FullResultSet
24
from django.db import DatabaseError, IntegrityError, NotSupportedError
35
from django.db.models import Count, Expression
46
from django.db.models.aggregates import Aggregate
57
from django.db.models.expressions import OrderBy
68
from django.db.models.sql import compiler
7-
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR
9+
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR, SINGLE
810
from django.utils.functional import cached_property
911

1012
from .base import Cursor
@@ -33,11 +35,21 @@ def execute_sql(
3335
except EmptyResultSet:
3436
return iter([]) if result_type == MULTI else None
3537

36-
return (
37-
(self._make_result(row, columns) for row in query.fetch())
38-
if result_type == MULTI
39-
else self._make_result(next(query.fetch()), columns)
40-
)
38+
cursor = query.get_cursor()
39+
if result_type == SINGLE:
40+
try:
41+
obj = cursor.next()
42+
except StopIteration:
43+
return None # No result
44+
else:
45+
return self._make_result(obj, columns)
46+
# result_type is MULTI
47+
cursor.batch_size(chunk_size)
48+
result = self.cursor_iter(cursor, chunk_size, columns)
49+
if not chunked_fetch:
50+
# If using non-chunked reads, read data into memory.
51+
return list(result)
52+
return result
4153

4254
def results_iter(
4355
self,
@@ -49,14 +61,23 @@ def results_iter(
4961
"""
5062
Return an iterator over the results from executing query given
5163
to this compiler. Called by QuerySet methods.
64+
65+
This method is copied from the superclass with one modification: the
66+
`if tuple_expected` block is deindented so that the result of
67+
_make_result() (a list) is cast to tuple as needed. For SQL database
68+
drivers, tuple results come from cursor.fetchmany(), so the cast is
69+
only needed there when apply_converters() casts the tuple to a list.
70+
This customized method could be removed if _make_result() cast its
71+
return value to a tuple, but that would be more expensive since that
72+
cast is not always needed.
5273
"""
5374
if results is None:
5475
# QuerySet.values() or values_list()
5576
results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size)
5677

5778
fields = [s[0] for s in self.select[0 : self.col_count]]
5879
converters = self.get_converters(fields)
59-
rows = results
80+
rows = chain.from_iterable(results)
6081
if converters:
6182
rows = self.apply_converters(rows, converters)
6283
if tuple_expected:
@@ -86,6 +107,16 @@ def _make_result(self, entity, columns):
86107
result.append(obj.get(name))
87108
return result
88109

110+
def cursor_iter(self, cursor, chunk_size, columns):
111+
"""Yield chunks of results from cursor."""
112+
chunk = []
113+
for row in cursor:
114+
chunk.append(self._make_result(row, columns))
115+
if len(chunk) == chunk_size:
116+
yield chunk
117+
chunk = []
118+
yield chunk
119+
89120
def check_query(self):
90121
"""Check if the current query is supported by the database."""
91122
if self.query.is_empty():
@@ -293,8 +324,14 @@ def check_query(self):
293324
)
294325

295326

296-
class SQLUpdateCompiler(SQLCompiler):
327+
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
297328
def execute_sql(self, result_type):
329+
"""
330+
Execute the specified update. Return the number of rows affected by
331+
the primary update query. The "primary update query" is the first
332+
non-empty query that is executed. Row counts for any subsequent,
333+
related queries are not available.
334+
"""
298335
self.pre_sql_setup()
299336
values = []
300337
for field, _, value in self.query.values:
@@ -309,7 +346,14 @@ def execute_sql(self, result_type):
309346
)
310347
prepared = field.get_db_prep_save(value, connection=self.connection)
311348
values.append((field, prepared))
312-
return self.update(values)
349+
is_empty = not bool(values)
350+
rows = 0 if is_empty else self.update(values)
351+
for query in self.query.get_related_updates():
352+
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
353+
if is_empty and aux_rows:
354+
rows = aux_rows
355+
is_empty = False
356+
return rows
313357

314358
def update(self, values):
315359
spec = {}

django_mongodb/query.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,6 @@ def __init__(self, compiler, columns):
5353
def __repr__(self):
5454
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
5555

56-
def fetch(self):
57-
"""Return an iterator over the query results."""
58-
yield from self.get_cursor()
59-
6056
@wrap_database_errors
6157
def count(self, limit=None, skip=None):
6258
"""

0 commit comments

Comments
 (0)