Skip to content

Commit 3e63192

Browse files
committed
INTPYTHON-348 aggregate via raw_mql
1 parent dc5238a commit 3e63192

File tree

2 files changed

+8
-274
lines changed

2 files changed

+8
-274
lines changed

django_mongodb/manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
3 3

4 4

5 5
class MongoManager(BaseManager.from_queryset(MongoQuerySet)):
6-
pass
6+
def get_query_set(self):
7+
return MongoQuerySet(self.model, using=self._db)

django_mongodb/query.py

Lines changed: 6 additions & 273 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
12 12
from django.db.models.sql.constants import INNER, GET_ITERATOR_CHUNK_SIZE
13 13
from django.db.models.sql.datastructures import Join
14 14
from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode
15+
from django.db.models.sql import Query
15 16
from django.utils.functional import cached_property
16 17
from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError
17 18

@@ -314,280 +315,12 @@ def register_nodes():
314 315

315 316

316 317
class MongoQuerySet(QuerySet):
317-
def raw_mql(self, raw_query, params=(), translations=None, using=None):
318-
if using is None:
319-
using = self.db
320-
qs = RawQuerySet(
321-
raw_query,
322-
model=self.model,
323-
params=params,
324-
translations=translations,
325-
using=using,
326-
)
327-
return qs
318+
def raw_mql(self, raw_query):
319+
return QuerySet(self.model, RawQuery(self.model, raw_query))
328 320

329 321

330-
class RawQuerySet:
331-
"""
332-
Provide an iterator which converts the results of raw SQL queries into
333-
annotated model instances.
334-
"""
322+
class RawQuery(Query):
335 323

336-
def __init__(
337-
self,
338-
raw_query,
339-
model=None,
340-
query=None,
341-
params=(),
342-
translations=None,
343-
using=None,
344-
hints=None,
345-
):
324+
def __init__(self, model, raw_query):
325+
super(RawQuery, self).__init__(model)
346 326
self.raw_query = raw_query
347-
self.model = model
348-
self._db = using
349-
self._hints = hints or {}
350-
self.query = query or RawQuery(sql=raw_query, using=self.db, params=params)
351-
self.params = params
352-
self.translations = translations or {}
353-
self._result_cache = None
354-
self._prefetch_related_lookups = ()
355-
self._prefetch_done = False
356-
357-
def resolve_model_init_order(self):
358-
"""Resolve the init field names and value positions."""
359-
converter = connections[self.db].introspection.identifier_converter
360-
model_init_fields = [
361-
f for f in self.model._meta.fields if converter(f.column) in self.columns
362-
]
363-
annotation_fields = [
364-
(column, pos)
365-
for pos, column in enumerate(self.columns)
366-
if column not in self.model_fields
367-
]
368-
model_init_order = [self.columns.index(converter(f.column)) for f in model_init_fields]
369-
model_init_names = [f.attname for f in model_init_fields]
370-
return model_init_names, model_init_order, annotation_fields
371-
372-
def prefetch_related(self, *lookups):
373-
"""Same as QuerySet.prefetch_related()"""
374-
clone = self._clone()
375-
if lookups == (None,):
376-
clone._prefetch_related_lookups = ()
377-
else:
378-
clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
379-
return clone
380-
381-
def _prefetch_related_objects(self):
382-
prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
383-
self._prefetch_done = True
384-
385-
def _clone(self):
386-
"""Same as QuerySet._clone()"""
387-
c = self.__class__(
388-
self.raw_query,
389-
model=self.model,
390-
query=self.query,
391-
params=self.params,
392-
translations=self.translations,
393-
using=self._db,
394-
hints=self._hints,
395-
)
396-
c._prefetch_related_lookups = self._prefetch_related_lookups[:]
397-
return c
398-
399-
def _fetch_all(self):
400-
if self._result_cache is None:
401-
self._result_cache = list(self.iterator())
402-
if self._prefetch_related_lookups and not self._prefetch_done:
403-
self._prefetch_related_objects()
404-
405-
def __len__(self):
406-
self._fetch_all()
407-
return len(self._result_cache)
408-
409-
def __bool__(self):
410-
self._fetch_all()
411-
return bool(self._result_cache)
412-
413-
def __iter__(self):
414-
self._fetch_all()
415-
return iter(self._result_cache)
416-
417-
def __aiter__(self):
418-
# Remember, __aiter__ itself is synchronous, it's the thing it returns
419-
# that is async!
420-
async def generator():
421-
await sync_to_async(self._fetch_all)()
422-
for item in self._result_cache:
423-
yield item
424-
425-
return generator()
426-
427-
def iterator(self):
428-
yield from RawModelIterable(self)
429-
430-
def __repr__(self):
431-
return "<%s: %s>" % (self.__class__.__name__, self.query)
432-
433-
def __getitem__(self, k):
434-
return list(self)[k]
435-
436-
@property
437-
def db(self):
438-
"""Return the database used if this query is executed now."""
439-
return self._db or router.db_for_read(self.model, **self._hints)
440-
441-
def using(self, alias):
442-
"""Select the database this RawQuerySet should execute against."""
443-
return RawQuerySet(
444-
self.raw_query,
445-
model=self.model,
446-
query=self.query.chain(using=alias),
447-
params=self.params,
448-
translations=self.translations,
449-
using=alias,
450-
)
451-
452-
@cached_property
453-
def columns(self):
454-
"""
455-
A list of model field names in the order they'll appear in the
456-
query results.
457-
"""
458-
columns = self.query.get_columns()
459-
# Adjust any column names which don't match field names
460-
for query_name, model_name in self.translations.items():
461-
# Ignore translations for nonexistent column names
462-
try:
463-
index = columns.index(query_name)
464-
except ValueError:
465-
pass
466-
else:
467-
columns[index] = model_name
468-
return columns
469-
470-
@cached_property
471-
def model_fields(self):
472-
"""A dict mapping column names to model field names."""
473-
converter = connections[self.db].introspection.identifier_converter
474-
model_fields = {}
475-
for field in self.model._meta.fields:
476-
name, column = field.get_attname_column()
477-
model_fields[converter(column)] = field
478-
return model_fields
479-
480-
481-
class RawQuery:
482-
"""A single raw SQL query."""
483-
484-
def __init__(self, sql, using, params=()):
485-
self.params = params
486-
self.sql = sql
487-
self.using = using
488-
self.cursor = None
489-
490-
# Mirror some properties of a normal query so that
491-
# the compiler can be used to process results.
492-
self.low_mark, self.high_mark = 0, None # Used for offset/limit
493-
self.extra_select = {}
494-
self.annotation_select = {}
495-
496-
def chain(self, using):
497-
return self.clone(using)
498-
499-
def clone(self, using):
500-
return RawQuery(self.sql, using, params=self.params)
501-
502-
def get_columns(self):
503-
if self.cursor is None:
504-
self._execute_query()
505-
converter = connections[self.using].introspection.identifier_converter
506-
return [converter(column_meta[0]) for column_meta in self.cursor.description]
507-
508-
def __iter__(self):
509-
# Always execute a new query for a new iterator.
510-
# This could be optimized with a cache at the expense of RAM.
511-
self._execute_query()
512-
if not connections[self.using].features.can_use_chunked_reads:
513-
# If the database can't use chunked reads we need to make sure we
514-
# evaluate the entire query up front.
515-
result = list(self.cursor)
516-
else:
517-
result = self.cursor
518-
return iter(result)
519-
520-
def __repr__(self):
521-
return "<%s: %s>" % (self.__class__.__name__, self)
522-
523-
@property
524-
def params_type(self):
525-
if self.params is None:
526-
return None
527-
return dict if isinstance(self.params, Mapping) else tuple
528-
529-
def __str__(self):
530-
if self.params_type is None:
531-
return self.sql
532-
return self.sql % self.params_type(self.params)
533-
534-
def _execute_query(self):
535-
connection = connections[self.using]
536-
537-
# Adapt parameters to the database, as much as possible considering
538-
# that the target type isn't known. See #17755.
539-
params_type = self.params_type
540-
adapter = connection.ops.adapt_unknown_value
541-
if params_type is tuple:
542-
params = tuple(adapter(val) for val in self.params)
543-
elif params_type is dict:
544-
params = {key: adapter(val) for key, val in self.params.items()}
545-
elif params_type is None:
546-
params = None
547-
else:
548-
raise RuntimeError("Unexpected params type: %s" % params_type)
549-
550-
self.cursor = connection.cursor()
551-
self.cursor.execute(self.sql, params)
552-
553-
554-
class RawModelIterable(BaseIterable):
555-
"""
556-
Iterable that yields a model instance for each row from a raw queryset.
557-
"""
558-
559-
def __iter__(self):
560-
# Cache some things for performance reasons outside the loop.
561-
db = self.queryset.db
562-
query = self.queryset.query
563-
connection = connections[db]
564-
compiler = connection.ops.compiler("SQLCompiler")(query, connection, db)
565-
query_iterator = iter(query)
566-
567-
try:
568-
(
569-
model_init_names,
570-
model_init_pos,
571-
annotation_fields,
572-
) = self.queryset.resolve_model_init_order()
573-
model_cls = self.queryset.model
574-
if model_cls._meta.pk.attname not in model_init_names:
575-
raise exceptions.FieldDoesNotExist("Raw query must include the primary key")
576-
fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
577-
converters = compiler.get_converters(
578-
[f.get_col(f.model._meta.db_table) if f else None for f in fields]
579-
)
580-
if converters:
581-
query_iterator = compiler.apply_converters(query_iterator, converters)
582-
for values in query_iterator:
583-
# Associate fields to values
584-
model_init_values = [values[pos] for pos in model_init_pos]
585-
instance = model_cls.from_db(db, model_init_names, model_init_values)
586-
if annotation_fields:
587-
for column, pos in annotation_fields:
588-
setattr(instance, column, values[pos])
589-
yield instance
590-
finally:
591-
# Done iterating the Query. If it has its own cursor, close it.
592-
if hasattr(query, "cursor") and query.cursor:
593-
query.cursor.close()

0 commit comments

Comments
 (0)