diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index bf2dbc973..d950c8d59 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -1,6 +1,6 @@ from django.core.exceptions import EmptyResultSet, FieldDoesNotExist, FullResultSet from django.db import DatabaseError, IntegrityError, NotSupportedError -from django.db.models import NOT_PROVIDED, Count, Expression +from django.db.models import Count, Expression from django.db.models.aggregates import Aggregate from django.db.models.constants import LOOKUP_SEP from django.db.models.sql import compiler @@ -23,14 +23,21 @@ def execute_sql( # QuerySet.count() if self.query.annotations == {"__count": Count("*")}: return [self.get_count()] - # Specify columns if there are any annotations so that annotations are - # computed via $project. - columns = self.get_columns() if self.query.annotations else None + + columns = self.get_columns() try: - query = self.build_query(columns) + query = self.build_query( + # Avoid $project (columns=None) if unneeded. + columns if self.query.annotations or not self.query.default_cols else None + ) except EmptyResultSet: - return None - return query.fetch() + return iter([]) if result_type == MULTI else None + + return ( + (self._make_result(row, columns) for row in query.fetch()) + if result_type == MULTI + else self._make_result(next(query.fetch()), columns) + ) def results_iter( self, @@ -43,37 +50,23 @@ def results_iter( Return an iterator over the results from executing query given to this compiler. Called by QuerySet methods. """ - columns = self.get_columns() - if results is None: # QuerySet.values() or values_list() - try: - results = self.build_query(columns).fetch() - except EmptyResultSet: - results = [] + results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size) - converters = self.get_converters(columns) - for entity in results: - yield self._make_result(entity, columns, converters, tuple_expected=tuple_expected) + fields = [s[0] for s in self.select[0 : self.col_count]] + converters = self.get_converters(fields) + rows = results + if converters: + rows = self.apply_converters(rows, converters) + if tuple_expected: + rows = map(tuple, rows) + return rows def has_results(self): return bool(self.get_count(check_exists=True)) - def get_converters(self, expressions): - converters = {} - for name_expr in expressions: - try: - name, expr = name_expr - except TypeError: - # e.g., Count("*") - continue - backend_converters = self.connection.ops.get_db_converters(expr) - field_converters = expr.get_db_converters(self.connection) - if backend_converters or field_converters: - converters[name] = backend_converters + field_converters - return converters - - def _make_result(self, entity, columns, converters, tuple_expected=False): + def _make_result(self, entity, columns): """ Decode values for the given fields from the database entity. @@ -82,7 +75,6 @@ def _make_result(self, entity, columns, converters, tuple_expected=False): """ result = [] for name, col in columns: - field = col.field column_alias = getattr(col, "alias", None) obj = ( # Use the related object... @@ -91,16 +83,7 @@ def _make_result(self, entity, columns, converters, tuple_expected=False): if column_alias is not None and column_alias != self.collection_name else entity ) - value = obj.get(name, NOT_PROVIDED) - if value is NOT_PROVIDED: - value = field.get_default() - elif converters: - # Decode values using Django's database converters API. - for converter in converters.get(name, ()): - value = converter(value, col, self.connection) - result.append(value) - if tuple_expected: - result = tuple(result) + result.append(obj.get(name, col.field.get_default())) return result def check_query(self):