Skip to content

Commit 626442b

Browse files
committed
tidy code and remove unncessary overrides
1 parent 27d5dbb commit 626442b

File tree

1 file changed

+25
-57
lines changed

1 file changed

+25
-57
lines changed

django_mongodb/query.py

Lines changed: 25 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -315,92 +315,62 @@ def raw_mql(self, pipeline, using=None):
315315

316316
class MongoRawQuery(RawQuery):
317317
def __init__(self, pipeline, using, model):
318+
self.pipeline = pipeline
318319
super().__init__(sql=None, using=using)
319320
self.model = model
320-
self.pipeline = pipeline
321-
322-
def __iter__(self):
323-
self._execute_query()
324-
return self.cursor
325321

326322
def _execute_query(self):
327323
connection = connections[self.using]
328324
collection = connection.get_collection(self.model._meta.db_table)
329325
self.cursor = collection.aggregate(self.pipeline)
330326

331-
def get_columns(self):
332-
return [f.column for f in self.model._meta.fields]
333-
334327
def __str__(self):
335-
return "%s" % self.pipeline
328+
return str(self.pipeline)
336329

337330

338331
class MongoRawQuerySet(RawQuerySet):
339-
def __init__(
340-
self,
341-
pipeline,
342-
model=None,
343-
query=None,
344-
translations=None,
345-
using=None,
346-
hints=None,
347-
):
348-
super().__init__(
349-
pipeline,
350-
model=model,
351-
query=query,
352-
using=using,
353-
hints=hints,
354-
translations=translations,
355-
)
356-
self.query = query or MongoRawQuery(pipeline, using=self.db, model=self.model)
332+
def __init__(self, pipeline, model=None, using=None):
333+
super().__init__(pipeline, model=model, using=using)
334+
self.query = MongoRawQuery(pipeline, using=self.db, model=self.model)
335+
# Override the superclass's columns property which relies on PEP 249's
336+
# cursor.description. Instead, RawModelIterable will set the columns
337+
# based on the keys in the first result.
338+
self.columns = None
357339

358340
def iterator(self):
359341
yield from MongoRawModelIterable(self)
360342

361-
def resolve_model_init_order(self, columns):
362-
"""Resolve the init field names and value positions."""
363-
model_init_fields = [f for f in self.model._meta.fields if f.column in columns]
364-
annotation_fields = [
365-
(column, pos) for pos, column in enumerate(columns) if column not in self.model_fields
366-
]
367-
model_init_order = [columns.index(f.column) for f in model_init_fields]
368-
model_init_names = [f.attname for f in model_init_fields]
369-
return model_init_names, model_init_order, annotation_fields
370-
371343

372344
class MongoRawModelIterable(RawModelIterable):
373-
"""
374-
Iterable that yields a model instance for each row from a raw queryset.
375-
"""
376-
377345
def __iter__(self):
378-
# Cache some things for performance reasons outside the loop.
346+
"""
347+
This is mostly copied from the superclass except for the part that
348+
sets self.queryset.columns from the first document.
349+
"""
379350
db = self.queryset.db
380351
query = self.queryset.query
381352
connection = connections[db]
382353
compiler = connection.ops.compiler("SQLCompiler")(query, connection, db)
383354
query_iterator = iter(query)
384-
# Get the columns from the first result.
385-
try:
386-
first_item = next(query_iterator)
387-
except StopIteration:
388-
# No results.
389-
query.cursor.close()
390-
return
391-
columns = list(first_item.keys())
392-
# Reset the iterator to include the first item.
393-
query_iterator = self._make_result(chain([first_item], query_iterator))
394355
try:
356+
# Get the columns from the first result.
357+
try:
358+
first_item = next(query_iterator)
359+
except StopIteration:
360+
# No results.
361+
return
362+
self.queryset.columns = list(first_item.keys())
363+
# Reset the iterator to include the first item.
364+
query_iterator = self._make_result(chain([first_item], query_iterator))
395365
(
396366
model_init_names,
397367
model_init_pos,
398368
annotation_fields,
399-
) = self.queryset.resolve_model_init_order(columns)
369+
) = self.queryset.resolve_model_init_order()
400370
model_cls = self.queryset.model
401371
if model_cls._meta.pk.attname not in model_init_names:
402372
raise FieldDoesNotExist("Raw query must include the primary key")
403-
fields = [self.queryset.model_fields.get(c) for c in columns]
373+
fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
404374
converters = compiler.get_converters(
405375
[f.get_col(f.model._meta.db_table) if f else None for f in fields]
406376
)
@@ -415,9 +385,7 @@ def __iter__(self):
415385
setattr(instance, column, values[pos])
416386
yield instance
417387
finally:
418-
# Done iterating the Query. If it has its own cursor, close it.
419-
if hasattr(query, "cursor") and query.cursor:
420-
query.cursor.close()
388+
query.cursor.close()
421389

422390
def _make_result(self, query):
423391
for result in query:

0 commit comments

Comments
 (0)