Skip to content

Commit dc5238a

Browse files
committed
INTPYTHON-348 aggregate via raw_mql
1 parent c4645c9 commit dc5238a

File tree

2 files changed

+292
-2
lines changed

2 files changed

+292
-2
lines changed

django_mongodb/manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from django.db.models.manager import BaseManager
2+
from .query import MongoQuerySet
3+
4+
5+
class MongoManager(BaseManager.from_queryset(MongoQuerySet)):
    """Manager class built from ``MongoQuerySet`` via ``from_queryset()``."""

django_mongodb/query.py

Lines changed: 286 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from functools import reduce, wraps
22
from operator import add as add_operator
3+
from collections.abc import Mapping
34

45
from django.core.exceptions import EmptyResultSet, FullResultSet
5-
from django.db import DatabaseError, IntegrityError, NotSupportedError
6+
from django.db import DatabaseError, IntegrityError, NotSupportedError, connections
7+
from django.db.models import QuerySet
68
from django.db.models.expressions import Case, Col, When
79
from django.db.models.functions import Mod
810
from django.db.models.lookups import Exact
9-
from django.db.models.sql.constants import INNER
11+
from django.db.models.query import BaseIterable
12+
from django.db.models.sql.constants import INNER, GET_ITERATOR_CHUNK_SIZE
1013
from django.db.models.sql.datastructures import Join
1114
from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode
15+
from django.utils.functional import cached_property
1216
from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError
1317

1418

@@ -307,3 +311,283 @@ def register_nodes():
307311
Join.as_mql = join
308312
NothingNode.as_mql = NothingNode.as_sql
309313
WhereNode.as_mql = where_node
314+
315+
316+
class MongoQuerySet(QuerySet):
    """QuerySet subclass that adds ``raw_mql()`` for running raw queries."""

    def raw_mql(self, raw_query, params=(), translations=None, using=None):
        """
        Wrap ``raw_query`` in a ``RawQuerySet`` bound to this queryset's model.

        ``params`` and ``translations`` are forwarded unchanged. When
        ``using`` is None, this queryset's own database alias (``self.db``)
        is used instead.
        """
        alias = self.db if using is None else using
        return RawQuerySet(
            raw_query,
            model=self.model,
            params=params,
            translations=translations,
            using=alias,
        )
328+
329+
330+
class RawQuerySet:
    """
    Provide an iterator which converts the results of raw queries into
    annotated model instances.

    Results are fetched lazily and cached in ``_result_cache`` the first time
    the queryset is iterated, measured with ``len()`` or tested for truth.
    """

    def __init__(
        self,
        raw_query,
        model=None,
        query=None,
        params=(),
        translations=None,
        using=None,
        hints=None,
    ):
        self.raw_query = raw_query
        self.model = model
        self._db = using
        self._hints = hints or {}
        # Build a RawQuery only when the caller did not hand one in (clones
        # and using() pass their existing/chained query through).
        self.query = query or RawQuery(sql=raw_query, using=self.db, params=params)
        self.params = params
        self.translations = translations or {}
        self._result_cache = None
        self._prefetch_related_lookups = ()
        self._prefetch_done = False

    def resolve_model_init_order(self):
        """
        Resolve the init field names and value positions.

        Return a 3-tuple ``(model_init_names, model_init_order,
        annotation_fields)``: the attnames of the model fields present in the
        result columns, the column position of each of those fields, and
        ``(column, position)`` pairs for columns that match no model field.
        """
        converter = connections[self.db].introspection.identifier_converter
        columns = self.columns
        model_init_fields = [
            f for f in self.model._meta.fields if converter(f.column) in columns
        ]
        annotation_fields = [
            (column, pos)
            for pos, column in enumerate(columns)
            if column not in self.model_fields
        ]
        model_init_order = [columns.index(converter(f.column)) for f in model_init_fields]
        model_init_names = [f.attname for f in model_init_fields]
        return model_init_names, model_init_order, annotation_fields

    def prefetch_related(self, *lookups):
        """
        Same as QuerySet.prefetch_related().

        Return a clone with ``lookups`` appended to its prefetch lookups;
        passing a single ``None`` clears them instead.
        """
        clone = self._clone()
        if lookups == (None,):
            clone._prefetch_related_lookups = ()
        else:
            clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
        return clone

    def _prefetch_related_objects(self):
        """Run the pending prefetch lookups against the cached results."""
        # Function-scope import: the name is used here but is not brought
        # into scope by the module's import block.
        from django.db.models import prefetch_related_objects

        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
        self._prefetch_done = True

    def _clone(self):
        """Same as QuerySet._clone()"""
        c = self.__class__(
            self.raw_query,
            model=self.model,
            query=self.query,
            params=self.params,
            translations=self.translations,
            using=self._db,
            hints=self._hints,
        )
        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
        return c

    def _fetch_all(self):
        """Populate the result cache (once) and run any pending prefetches."""
        if self._result_cache is None:
            self._result_cache = list(self.iterator())
        if self._prefetch_related_lookups and not self._prefetch_done:
            self._prefetch_related_objects()

    def __len__(self):
        self._fetch_all()
        return len(self._result_cache)

    def __bool__(self):
        self._fetch_all()
        return bool(self._result_cache)

    def __iter__(self):
        self._fetch_all()
        return iter(self._result_cache)

    def __aiter__(self):
        # Remember, __aiter__ itself is synchronous, it's the thing it returns
        # that is async!
        # Function-scope import: the name is used here but is not brought
        # into scope by the module's import block.
        from asgiref.sync import sync_to_async

        async def generator():
            await sync_to_async(self._fetch_all)()
            for item in self._result_cache:
                yield item

        return generator()

    def iterator(self):
        """Yield model instances straight from the query, bypassing the cache."""
        yield from RawModelIterable(self)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.query)

    def __getitem__(self, k):
        # Evaluates (and caches) the full result set, then indexes/slices it.
        return list(self)[k]

    @property
    def db(self):
        """Return the database used if this query is executed now."""
        if self._db:
            return self._db
        # Function-scope import: the name is used here but is not brought
        # into scope by the module's import block.
        from django.db import router

        return router.db_for_read(self.model, **self._hints)

    def using(self, alias):
        """Select the database this RawQuerySet should execute against."""
        return RawQuerySet(
            self.raw_query,
            model=self.model,
            query=self.query.chain(using=alias),
            params=self.params,
            translations=self.translations,
            using=alias,
        )

    @cached_property
    def columns(self):
        """
        A list of model field names in the order they'll appear in the
        query results.
        """
        columns = self.query.get_columns()
        # Adjust any column names which don't match field names
        for query_name, model_name in self.translations.items():
            # Ignore translations for nonexistent column names
            try:
                index = columns.index(query_name)
            except ValueError:
                pass
            else:
                columns[index] = model_name
        return columns

    @cached_property
    def model_fields(self):
        """A dict mapping converted column names to model field instances."""
        converter = connections[self.db].introspection.identifier_converter
        model_fields = {}
        for field in self.model._meta.fields:
            name, column = field.get_attname_column()
            model_fields[converter(column)] = field
        return model_fields
479+
480+
481+
class RawQuery:
    """A single raw query: its text, parameters, and target database alias."""

    def __init__(self, sql, using, params=()):
        self.sql = sql
        self.using = using
        self.params = params
        self.cursor = None

        # Mirror some properties of a normal query so that
        # the compiler can be used to process results.
        self.low_mark = 0  # offset
        self.high_mark = None  # limit
        self.extra_select = {}
        self.annotation_select = {}

    def chain(self, using):
        """Return a copy of this query targeting the ``using`` alias."""
        return self.clone(using)

    def clone(self, using):
        """Return a new RawQuery with the same text and params for ``using``."""
        return RawQuery(self.sql, using, params=self.params)

    def get_columns(self):
        """
        Return the converted column names from the cursor description,
        executing the query first if no cursor exists yet.
        """
        if self.cursor is None:
            self._execute_query()
        convert = connections[self.using].introspection.identifier_converter
        names = []
        for column_meta in self.cursor.description:
            names.append(convert(column_meta[0]))
        return names

    def __iter__(self):
        # Always execute a new query for a new iterator.
        # This could be optimized with a cache at the expense of RAM.
        self._execute_query()
        if connections[self.using].features.can_use_chunked_reads:
            return iter(self.cursor)
        # If the database can't use chunked reads we need to make sure we
        # evaluate the entire query up front.
        return iter(list(self.cursor))

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    @property
    def params_type(self):
        """``None`` for no params, ``dict`` for a Mapping, otherwise ``tuple``."""
        if self.params is None:
            return None
        if isinstance(self.params, Mapping):
            return dict
        return tuple

    def __str__(self):
        kind = self.params_type
        if kind is None:
            return self.sql
        return self.sql % kind(self.params)

    def _execute_query(self):
        """Open a cursor on the target connection and execute the query."""
        connection = connections[self.using]

        # Adapt parameters to the database, as much as possible considering
        # that the target type isn't known. See #17755.
        kind = self.params_type
        adapt = connection.ops.adapt_unknown_value
        if kind is None:
            params = None
        elif kind is tuple:
            params = tuple(map(adapt, self.params))
        elif kind is dict:
            params = {key: adapt(val) for key, val in self.params.items()}
        else:
            raise RuntimeError("Unexpected params type: %s" % kind)

        self.cursor = connection.cursor()
        self.cursor.execute(self.sql, params)
552+
553+
554+
class RawModelIterable(BaseIterable):
    """
    Iterable that yields a model instance for each row from a raw queryset.
    """

    def __iter__(self):
        """
        Execute the raw query and yield one model instance per result row.

        Columns that match no model field are set on each instance as plain
        attributes. The query's cursor, if any, is closed when iteration
        finishes or is abandoned.

        Raises:
            FieldDoesNotExist: if the result columns do not include the
                model's primary key.
        """
        # Function-scope import: the original referenced an `exceptions`
        # module that the file's import block does not bring into scope.
        from django.core.exceptions import FieldDoesNotExist

        # Cache some things for performance reasons outside the loop.
        db = self.queryset.db
        query = self.queryset.query
        connection = connections[db]
        compiler = connection.ops.compiler("SQLCompiler")(query, connection, db)
        query_iterator = iter(query)

        try:
            (
                model_init_names,
                model_init_pos,
                annotation_fields,
            ) = self.queryset.resolve_model_init_order()
            model_cls = self.queryset.model
            if model_cls._meta.pk.attname not in model_init_names:
                raise FieldDoesNotExist("Raw query must include the primary key")
            # One entry per result column: the matching model field, or None
            # for columns that are not model fields.
            fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
            converters = compiler.get_converters(
                [f.get_col(f.model._meta.db_table) if f else None for f in fields]
            )
            if converters:
                query_iterator = compiler.apply_converters(query_iterator, converters)
            for values in query_iterator:
                # Associate fields to values
                model_init_values = [values[pos] for pos in model_init_pos]
                instance = model_cls.from_db(db, model_init_names, model_init_values)
                for column, pos in annotation_fields:
                    setattr(instance, column, values[pos])
                yield instance
        finally:
            # Done iterating the Query. If it has its own cursor, close it.
            cursor = getattr(query, "cursor", None)
            if cursor:
                cursor.close()

0 commit comments

Comments
 (0)