Skip to content

Commit 4f52df0

Browse files
WaVEVtimgraham
authored andcommitted
conform SQLCompiler.execute_sql() and results_iter() to Django's API
1 parent e347d97 commit 4f52df0

File tree

1 file changed

+25
-42
lines changed

1 file changed

+25
-42
lines changed

django_mongodb/compiler.py

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.core.exceptions import EmptyResultSet, FieldDoesNotExist, FullResultSet
22
from django.db import DatabaseError, IntegrityError, NotSupportedError
3-
from django.db.models import NOT_PROVIDED, Count, Expression
3+
from django.db.models import Count, Expression
44
from django.db.models.aggregates import Aggregate
55
from django.db.models.constants import LOOKUP_SEP
66
from django.db.models.sql import compiler
@@ -23,14 +23,21 @@ def execute_sql(
2323
# QuerySet.count()
2424
if self.query.annotations == {"__count": Count("*")}:
2525
return [self.get_count()]
26-
# Specify columns if there are any annotations so that annotations are
27-
# computed via $project.
28-
columns = self.get_columns() if self.query.annotations else None
26+
27+
columns = self.get_columns()
2928
try:
30-
query = self.build_query(columns)
29+
query = self.build_query(
30+
# Avoid $project (columns=None) if unneeded.
31+
columns if self.query.annotations or not self.query.default_cols else None
32+
)
3133
except EmptyResultSet:
32-
return None
33-
return query.fetch()
34+
return iter([]) if result_type == MULTI else None
35+
36+
return (
37+
(self._make_result(row, columns) for row in query.fetch())
38+
if result_type == MULTI
39+
else self._make_result(next(query.fetch()), columns)
40+
)
3441

3542
def results_iter(
3643
self,
@@ -43,37 +50,23 @@ def results_iter(
4350
Return an iterator over the results from executing query given
4451
to this compiler. Called by QuerySet methods.
4552
"""
46-
columns = self.get_columns()
47-
4853
if results is None:
4954
# QuerySet.values() or values_list()
50-
try:
51-
results = self.build_query(columns).fetch()
52-
except EmptyResultSet:
53-
results = []
55+
results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size)
5456

55-
converters = self.get_converters(columns)
56-
for entity in results:
57-
yield self._make_result(entity, columns, converters, tuple_expected=tuple_expected)
57+
fields = [s[0] for s in self.select[0 : self.col_count]]
58+
converters = self.get_converters(fields)
59+
rows = results
60+
if converters:
61+
rows = self.apply_converters(rows, converters)
62+
if tuple_expected:
63+
rows = map(tuple, rows)
64+
return rows
5865

5966
def has_results(self):
6067
return bool(self.get_count(check_exists=True))
6168

62-
def get_converters(self, expressions):
63-
converters = {}
64-
for name_expr in expressions:
65-
try:
66-
name, expr = name_expr
67-
except TypeError:
68-
# e.g., Count("*")
69-
continue
70-
backend_converters = self.connection.ops.get_db_converters(expr)
71-
field_converters = expr.get_db_converters(self.connection)
72-
if backend_converters or field_converters:
73-
converters[name] = backend_converters + field_converters
74-
return converters
75-
76-
def _make_result(self, entity, columns, converters, tuple_expected=False):
69+
def _make_result(self, entity, columns):
7770
"""
7871
Decode values for the given fields from the database entity.
7972
@@ -82,7 +75,6 @@ def _make_result(self, entity, columns, converters, tuple_expected=False):
8275
"""
8376
result = []
8477
for name, col in columns:
85-
field = col.field
8678
column_alias = getattr(col, "alias", None)
8779
obj = (
8880
# Use the related object...
@@ -91,16 +83,7 @@ def _make_result(self, entity, columns, converters, tuple_expected=False):
9183
if column_alias is not None and column_alias != self.collection_name
9284
else entity
9385
)
94-
value = obj.get(name, NOT_PROVIDED)
95-
if value is NOT_PROVIDED:
96-
value = field.get_default()
97-
elif converters:
98-
# Decode values using Django's database converters API.
99-
for converter in converters.get(name, ()):
100-
value = converter(value, col, self.connection)
101-
result.append(value)
102-
if tuple_expected:
103-
result = tuple(result)
86+
result.append(obj.get(name, col.field.get_default()))
10487
return result
10588

10689
def check_query(self):

0 commit comments

Comments
 (0)