diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index d4d717fa8..80751ca0b 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -1,10 +1,12 @@ +from itertools import chain + from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import DatabaseError, IntegrityError, NotSupportedError from django.db.models import Count, Expression from django.db.models.aggregates import Aggregate from django.db.models.expressions import OrderBy from django.db.models.sql import compiler -from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR +from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, ORDER_DIR, SINGLE from django.utils.functional import cached_property from .base import Cursor @@ -33,11 +35,21 @@ def execute_sql( except EmptyResultSet: return iter([]) if result_type == MULTI else None - return ( - (self._make_result(row, columns) for row in query.fetch()) - if result_type == MULTI - else self._make_result(next(query.fetch()), columns) - ) + cursor = query.get_cursor() + if result_type == SINGLE: + try: + obj = cursor.next() + except StopIteration: + return None # No result + else: + return self._make_result(obj, columns) + # result_type is MULTI + cursor.batch_size(chunk_size) + result = self.cursor_iter(cursor, chunk_size, columns) + if not chunked_fetch: + # If using non-chunked reads, read data into memory. + return list(result) + return result def results_iter( self, @@ -49,6 +61,15 @@ def results_iter( """ Return an iterator over the results from executing query given to this compiler. Called by QuerySet methods. + + This method is copied from the superclass with one modification: the + `if tuple_expected` block is deindented so that the result of + _make_result() (a list) is cast to tuple as needed. For SQL database + drivers, tuple results come from cursor.fetchmany(), so the cast is + only needed there when apply_converters() casts the tuple to a list. + This customized method could be removed if _make_result() cast its + return value to a tuple, but that would be more expensive since that + cast is not always needed. """ if results is None: # QuerySet.values() or values_list() @@ -56,7 +77,7 @@ def results_iter( fields = [s[0] for s in self.select[0 : self.col_count]] converters = self.get_converters(fields) - rows = results + rows = chain.from_iterable(results) if converters: rows = self.apply_converters(rows, converters) if tuple_expected: @@ -86,6 +107,16 @@ def _make_result(self, entity, columns): result.append(obj.get(name)) return result + def cursor_iter(self, cursor, chunk_size, columns): + """Yield chunks of results from cursor.""" + chunk = [] + for row in cursor: + chunk.append(self._make_result(row, columns)) + if len(chunk) == chunk_size: + yield chunk + chunk = [] + yield chunk + def check_query(self): """Check if the current query is supported by the database.""" if self.query.is_empty(): diff --git a/django_mongodb/query.py b/django_mongodb/query.py index e48b99363..f941f3c1e 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -53,10 +53,6 @@ def __init__(self, compiler, columns): def __repr__(self): return f"" - def fetch(self): - """Return an iterator over the query results.""" - yield from self.get_cursor() - @wrap_database_errors def count(self, limit=None, skip=None): """