Skip to content

Commit 324d379

Browse files
committed
refactor compiler.
1 parent 72a20d4 commit 324d379

File tree

1 file changed

+33
-39
lines changed

1 file changed

+33
-39
lines changed

django_mongodb/compiler.py

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from django.db.models.constants import LOOKUP_SEP
66
from django.db.models.sql import compiler
77
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI
8+
from django.utils.functional import cached_property
89

910
from .base import Cursor
1011
from .query import MongoQuery, wrap_database_errors
@@ -22,14 +23,18 @@ def execute_sql(
2223
# QuerySet.count()
2324
if self.query.annotations == {"__count": Count("*")}:
2425
return [self.get_count()]
25-
# Specify columns if there are any annotations so that annotations are
26-
# computed via $project.
27-
columns = self.get_columns() if self.query.annotations else None
26+
27+
columns = self.get_columns()
2828
try:
2929
query = self.build_query(columns)
3030
except EmptyResultSet:
3131
return None
32-
return query.fetch()
32+
33+
return (
34+
(self._make_result(row, columns) for row in query.fetch())
35+
if result_type == MULTI
36+
else self._make_result(next(query.fetch()), columns)
37+
)
3338

3439
def results_iter(
3540
self,
@@ -42,37 +47,24 @@ def results_iter(
4247
Return an iterator over the results from executing query given
4348
to this compiler. Called by QuerySet methods.
4449
"""
45-
columns = self.get_columns()
4650

4751
if results is None:
4852
# QuerySet.values() or values_list()
49-
try:
50-
results = self.build_query(columns).fetch()
51-
except EmptyResultSet:
52-
results = []
53+
results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size)
5354

54-
converters = self.get_converters(columns)
55-
for entity in results:
56-
yield self._make_result(entity, columns, converters, tuple_expected=tuple_expected)
55+
fields = [s[0] for s in self.select[0 : self.col_count]]
56+
converters = self.get_converters(fields)
57+
rows = results
58+
if converters:
59+
rows = self.apply_converters(rows, converters)
60+
if tuple_expected:
61+
rows = map(tuple, rows)
62+
return rows
5763

5864
def has_results(self):
5965
return bool(self.get_count(check_exists=True))
6066

61-
def get_converters(self, expressions):
62-
converters = {}
63-
for name_expr in expressions:
64-
try:
65-
name, expr = name_expr
66-
except TypeError:
67-
# e.g., Count("*")
68-
continue
69-
backend_converters = self.connection.ops.get_db_converters(expr)
70-
field_converters = expr.get_db_converters(self.connection)
71-
if backend_converters or field_converters:
72-
converters[name] = backend_converters + field_converters
73-
return converters
74-
75-
def _make_result(self, entity, columns, converters, tuple_expected=False):
67+
def _make_result(self, entity, columns):
7668
"""
7769
Decode values for the given fields from the database entity.
7870
@@ -81,17 +73,15 @@ def _make_result(self, entity, columns, converters, tuple_expected=False):
8173
"""
8274
result = []
8375
for name, col in columns:
84-
field = col.field
85-
value = entity.get(name, NOT_PROVIDED)
86-
if value is NOT_PROVIDED:
87-
value = field.get_default()
88-
elif converters:
89-
# Decode values using Django's database converters API.
90-
for converter in converters.get(name, ()):
91-
value = converter(value, col, self.connection)
92-
result.append(value)
93-
if tuple_expected:
94-
result = tuple(result)
76+
column_alias = getattr(col, "alias", None)
77+
obj = (
78+
# Use the related object...
79+
entity.get(column_alias, {})
80+
# ...if this column refers to an object for select_related().
81+
if column_alias is not None and column_alias != self.collection_name
82+
else entity
83+
)
84+
result.append(obj.get(name, NOT_PROVIDED))
9585
return result
9686

9787
def check_query(self):
@@ -212,8 +202,12 @@ def _get_ordering(self):
212202
field_ordering.append((opts.get_field(name), ascending))
213203
return field_ordering
214204

205+
@cached_property
206+
def collection_name(self):
207+
return self.query.get_meta().db_table
208+
215209
def get_collection(self):
216-
return self.connection.get_collection(self.query.get_meta().db_table)
210+
return self.connection.get_collection(self.collection_name)
217211

218212

219213
class SQLInsertCompiler(SQLCompiler):

0 commit comments

Comments
 (0)