@@ -315,92 +315,62 @@ def raw_mql(self, pipeline, using=None):
315
315
316
316
class MongoRawQuery (RawQuery ):
317
317
def __init__ (self , pipeline , using , model ):
318
+ self .pipeline = pipeline
318
319
super ().__init__ (sql = None , using = using )
319
320
self .model = model
320
- self .pipeline = pipeline
321
-
322
- def __iter__ (self ):
323
- self ._execute_query ()
324
- return self .cursor
325
321
326
322
def _execute_query (self ):
327
323
connection = connections [self .using ]
328
324
collection = connection .get_collection (self .model ._meta .db_table )
329
325
self .cursor = collection .aggregate (self .pipeline )
330
326
331
- def get_columns (self ):
332
- return [f .column for f in self .model ._meta .fields ]
333
-
334
327
def __str__ (self ):
335
- return "%s" % self .pipeline
328
+ return str ( self .pipeline )
336
329
337
330
338
331
class MongoRawQuerySet (RawQuerySet ):
339
- def __init__ (
340
- self ,
341
- pipeline ,
342
- model = None ,
343
- query = None ,
344
- translations = None ,
345
- using = None ,
346
- hints = None ,
347
- ):
348
- super ().__init__ (
349
- pipeline ,
350
- model = model ,
351
- query = query ,
352
- using = using ,
353
- hints = hints ,
354
- translations = translations ,
355
- )
356
- self .query = query or MongoRawQuery (pipeline , using = self .db , model = self .model )
332
+ def __init__ (self , pipeline , model = None , using = None ):
333
+ super ().__init__ (pipeline , model = model , using = using )
334
+ self .query = MongoRawQuery (pipeline , using = self .db , model = self .model )
335
+ # Override the superclass's columns property which relies on PEP 249's
336
+ # cursor.description. Instead, RawModelIterable will set the columns
337
+ # based on the keys in the first result.
338
+ self .columns = None
357
339
358
340
def iterator (self ):
359
341
yield from MongoRawModelIterable (self )
360
342
361
- def resolve_model_init_order (self , columns ):
362
- """Resolve the init field names and value positions."""
363
- model_init_fields = [f for f in self .model ._meta .fields if f .column in columns ]
364
- annotation_fields = [
365
- (column , pos ) for pos , column in enumerate (columns ) if column not in self .model_fields
366
- ]
367
- model_init_order = [columns .index (f .column ) for f in model_init_fields ]
368
- model_init_names = [f .attname for f in model_init_fields ]
369
- return model_init_names , model_init_order , annotation_fields
370
-
371
343
372
344
class MongoRawModelIterable (RawModelIterable ):
373
- """
374
- Iterable that yields a model instance for each row from a raw queryset.
375
- """
376
-
377
345
def __iter__ (self ):
378
- # Cache some things for performance reasons outside the loop.
346
+ """
347
+ This is mostly copied from the superclass except for the part that
348
+ sets self.queryset.columns from the first document.
349
+ """
379
350
db = self .queryset .db
380
351
query = self .queryset .query
381
352
connection = connections [db ]
382
353
compiler = connection .ops .compiler ("SQLCompiler" )(query , connection , db )
383
354
query_iterator = iter (query )
384
- # Get the columns from the first result.
385
- try :
386
- first_item = next (query_iterator )
387
- except StopIteration :
388
- # No results.
389
- query .cursor .close ()
390
- return
391
- columns = list (first_item .keys ())
392
- # Reset the iterator to include the first item.
393
- query_iterator = self ._make_result (chain ([first_item ], query_iterator ))
394
355
try :
356
+ # Get the columns from the first result.
357
+ try :
358
+ first_item = next (query_iterator )
359
+ except StopIteration :
360
+ # No results.
361
+ return
362
+ self .queryset .columns = list (first_item .keys ())
363
+ # Reset the iterator to include the first item.
364
+ query_iterator = self ._make_result (chain ([first_item ], query_iterator ))
395
365
(
396
366
model_init_names ,
397
367
model_init_pos ,
398
368
annotation_fields ,
399
- ) = self .queryset .resolve_model_init_order (columns )
369
+ ) = self .queryset .resolve_model_init_order ()
400
370
model_cls = self .queryset .model
401
371
if model_cls ._meta .pk .attname not in model_init_names :
402
372
raise FieldDoesNotExist ("Raw query must include the primary key" )
403
- fields = [self .queryset .model_fields .get (c ) for c in columns ]
373
+ fields = [self .queryset .model_fields .get (c ) for c in self . queryset . columns ]
404
374
converters = compiler .get_converters (
405
375
[f .get_col (f .model ._meta .db_table ) if f else None for f in fields ]
406
376
)
@@ -415,9 +385,7 @@ def __iter__(self):
415
385
setattr (instance , column , values [pos ])
416
386
yield instance
417
387
finally :
418
- # Done iterating the Query. If it has its own cursor, close it.
419
- if hasattr (query , "cursor" ) and query .cursor :
420
- query .cursor .close ()
388
+ query .cursor .close ()
421
389
422
390
def _make_result (self , query ):
423
391
for result in query :
0 commit comments