|
12 | 12 | from django.db.models.sql.constants import INNER, GET_ITERATOR_CHUNK_SIZE
|
13 | 13 | from django.db.models.sql.datastructures import Join
|
14 | 14 | from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode
|
| 15 | +from django.db.models.sql import Query |
15 | 16 | from django.utils.functional import cached_property
|
16 | 17 | from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError
|
17 | 18 |
|
@@ -314,280 +315,12 @@ def register_nodes():
|
314 | 315 |
|
315 | 316 |
|
316 | 317 | class MongoQuerySet(QuerySet):
|
317 |
| - def raw_mql(self, raw_query, params=(), translations=None, using=None): |
318 |
| - if using is None: |
319 |
| - using = self.db |
320 |
| - qs = RawQuerySet( |
321 |
| - raw_query, |
322 |
| - model=self.model, |
323 |
| - params=params, |
324 |
| - translations=translations, |
325 |
| - using=using, |
326 |
| - ) |
327 |
| - return qs |
| 318 | + def raw_mql(self, raw_query): |
| 319 | + return QuerySet(self.model, RawQuery(self.model, raw_query)) |
328 | 320 |
|
329 | 321 |
|
330 |
| -class RawQuerySet: |
331 |
| - """ |
332 |
| - Provide an iterator which converts the results of raw SQL queries into |
333 |
| - annotated model instances. |
334 |
| - """ |
| 322 | +class RawQuery(Query): |
335 | 323 |
|
336 |
| - def __init__( |
337 |
| - self, |
338 |
| - raw_query, |
339 |
| - model=None, |
340 |
| - query=None, |
341 |
| - params=(), |
342 |
| - translations=None, |
343 |
| - using=None, |
344 |
| - hints=None, |
345 |
| - ): |
| 324 | + def __init__(self, model, raw_query): |
| 325 | + super(RawQuery, self).__init__(model) |
346 | 326 | self.raw_query = raw_query
|
347 |
| - self.model = model |
348 |
| - self._db = using |
349 |
| - self._hints = hints or {} |
350 |
| - self.query = query or RawQuery(sql=raw_query, using=self.db, params=params) |
351 |
| - self.params = params |
352 |
| - self.translations = translations or {} |
353 |
| - self._result_cache = None |
354 |
| - self._prefetch_related_lookups = () |
355 |
| - self._prefetch_done = False |
356 |
| - |
357 |
| - def resolve_model_init_order(self): |
358 |
| - """Resolve the init field names and value positions.""" |
359 |
| - converter = connections[self.db].introspection.identifier_converter |
360 |
| - model_init_fields = [ |
361 |
| - f for f in self.model._meta.fields if converter(f.column) in self.columns |
362 |
| - ] |
363 |
| - annotation_fields = [ |
364 |
| - (column, pos) |
365 |
| - for pos, column in enumerate(self.columns) |
366 |
| - if column not in self.model_fields |
367 |
| - ] |
368 |
| - model_init_order = [self.columns.index(converter(f.column)) for f in model_init_fields] |
369 |
| - model_init_names = [f.attname for f in model_init_fields] |
370 |
| - return model_init_names, model_init_order, annotation_fields |
371 |
| - |
372 |
| - def prefetch_related(self, *lookups): |
373 |
| - """Same as QuerySet.prefetch_related()""" |
374 |
| - clone = self._clone() |
375 |
| - if lookups == (None,): |
376 |
| - clone._prefetch_related_lookups = () |
377 |
| - else: |
378 |
| - clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups |
379 |
| - return clone |
380 |
| - |
381 |
| - def _prefetch_related_objects(self): |
382 |
| - prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups) |
383 |
| - self._prefetch_done = True |
384 |
| - |
385 |
| - def _clone(self): |
386 |
| - """Same as QuerySet._clone()""" |
387 |
| - c = self.__class__( |
388 |
| - self.raw_query, |
389 |
| - model=self.model, |
390 |
| - query=self.query, |
391 |
| - params=self.params, |
392 |
| - translations=self.translations, |
393 |
| - using=self._db, |
394 |
| - hints=self._hints, |
395 |
| - ) |
396 |
| - c._prefetch_related_lookups = self._prefetch_related_lookups[:] |
397 |
| - return c |
398 |
| - |
399 |
| - def _fetch_all(self): |
400 |
| - if self._result_cache is None: |
401 |
| - self._result_cache = list(self.iterator()) |
402 |
| - if self._prefetch_related_lookups and not self._prefetch_done: |
403 |
| - self._prefetch_related_objects() |
404 |
| - |
405 |
| - def __len__(self): |
406 |
| - self._fetch_all() |
407 |
| - return len(self._result_cache) |
408 |
| - |
409 |
| - def __bool__(self): |
410 |
| - self._fetch_all() |
411 |
| - return bool(self._result_cache) |
412 |
| - |
413 |
| - def __iter__(self): |
414 |
| - self._fetch_all() |
415 |
| - return iter(self._result_cache) |
416 |
| - |
417 |
| - def __aiter__(self): |
418 |
| - # Remember, __aiter__ itself is synchronous, it's the thing it returns |
419 |
| - # that is async! |
420 |
| - async def generator(): |
421 |
| - await sync_to_async(self._fetch_all)() |
422 |
| - for item in self._result_cache: |
423 |
| - yield item |
424 |
| - |
425 |
| - return generator() |
426 |
| - |
427 |
| - def iterator(self): |
428 |
| - yield from RawModelIterable(self) |
429 |
| - |
430 |
| - def __repr__(self): |
431 |
| - return "<%s: %s>" % (self.__class__.__name__, self.query) |
432 |
| - |
433 |
| - def __getitem__(self, k): |
434 |
| - return list(self)[k] |
435 |
| - |
436 |
| - @property |
437 |
| - def db(self): |
438 |
| - """Return the database used if this query is executed now.""" |
439 |
| - return self._db or router.db_for_read(self.model, **self._hints) |
440 |
| - |
441 |
| - def using(self, alias): |
442 |
| - """Select the database this RawQuerySet should execute against.""" |
443 |
| - return RawQuerySet( |
444 |
| - self.raw_query, |
445 |
| - model=self.model, |
446 |
| - query=self.query.chain(using=alias), |
447 |
| - params=self.params, |
448 |
| - translations=self.translations, |
449 |
| - using=alias, |
450 |
| - ) |
451 |
| - |
452 |
| - @cached_property |
453 |
| - def columns(self): |
454 |
| - """ |
455 |
| - A list of model field names in the order they'll appear in the |
456 |
| - query results. |
457 |
| - """ |
458 |
| - columns = self.query.get_columns() |
459 |
| - # Adjust any column names which don't match field names |
460 |
| - for query_name, model_name in self.translations.items(): |
461 |
| - # Ignore translations for nonexistent column names |
462 |
| - try: |
463 |
| - index = columns.index(query_name) |
464 |
| - except ValueError: |
465 |
| - pass |
466 |
| - else: |
467 |
| - columns[index] = model_name |
468 |
| - return columns |
469 |
| - |
470 |
| - @cached_property |
471 |
| - def model_fields(self): |
472 |
| - """A dict mapping column names to model field names.""" |
473 |
| - converter = connections[self.db].introspection.identifier_converter |
474 |
| - model_fields = {} |
475 |
| - for field in self.model._meta.fields: |
476 |
| - name, column = field.get_attname_column() |
477 |
| - model_fields[converter(column)] = field |
478 |
| - return model_fields |
479 |
| - |
480 |
| - |
481 |
| -class RawQuery: |
482 |
| - """A single raw SQL query.""" |
483 |
| - |
484 |
| - def __init__(self, sql, using, params=()): |
485 |
| - self.params = params |
486 |
| - self.sql = sql |
487 |
| - self.using = using |
488 |
| - self.cursor = None |
489 |
| - |
490 |
| - # Mirror some properties of a normal query so that |
491 |
| - # the compiler can be used to process results. |
492 |
| - self.low_mark, self.high_mark = 0, None # Used for offset/limit |
493 |
| - self.extra_select = {} |
494 |
| - self.annotation_select = {} |
495 |
| - |
496 |
| - def chain(self, using): |
497 |
| - return self.clone(using) |
498 |
| - |
499 |
| - def clone(self, using): |
500 |
| - return RawQuery(self.sql, using, params=self.params) |
501 |
| - |
502 |
| - def get_columns(self): |
503 |
| - if self.cursor is None: |
504 |
| - self._execute_query() |
505 |
| - converter = connections[self.using].introspection.identifier_converter |
506 |
| - return [converter(column_meta[0]) for column_meta in self.cursor.description] |
507 |
| - |
508 |
| - def __iter__(self): |
509 |
| - # Always execute a new query for a new iterator. |
510 |
| - # This could be optimized with a cache at the expense of RAM. |
511 |
| - self._execute_query() |
512 |
| - if not connections[self.using].features.can_use_chunked_reads: |
513 |
| - # If the database can't use chunked reads we need to make sure we |
514 |
| - # evaluate the entire query up front. |
515 |
| - result = list(self.cursor) |
516 |
| - else: |
517 |
| - result = self.cursor |
518 |
| - return iter(result) |
519 |
| - |
520 |
| - def __repr__(self): |
521 |
| - return "<%s: %s>" % (self.__class__.__name__, self) |
522 |
| - |
523 |
| - @property |
524 |
| - def params_type(self): |
525 |
| - if self.params is None: |
526 |
| - return None |
527 |
| - return dict if isinstance(self.params, Mapping) else tuple |
528 |
| - |
529 |
| - def __str__(self): |
530 |
| - if self.params_type is None: |
531 |
| - return self.sql |
532 |
| - return self.sql % self.params_type(self.params) |
533 |
| - |
534 |
| - def _execute_query(self): |
535 |
| - connection = connections[self.using] |
536 |
| - |
537 |
| - # Adapt parameters to the database, as much as possible considering |
538 |
| - # that the target type isn't known. See #17755. |
539 |
| - params_type = self.params_type |
540 |
| - adapter = connection.ops.adapt_unknown_value |
541 |
| - if params_type is tuple: |
542 |
| - params = tuple(adapter(val) for val in self.params) |
543 |
| - elif params_type is dict: |
544 |
| - params = {key: adapter(val) for key, val in self.params.items()} |
545 |
| - elif params_type is None: |
546 |
| - params = None |
547 |
| - else: |
548 |
| - raise RuntimeError("Unexpected params type: %s" % params_type) |
549 |
| - |
550 |
| - self.cursor = connection.cursor() |
551 |
| - self.cursor.execute(self.sql, params) |
552 |
| - |
553 |
| - |
554 |
| -class RawModelIterable(BaseIterable): |
555 |
| - """ |
556 |
| - Iterable that yields a model instance for each row from a raw queryset. |
557 |
| - """ |
558 |
| - |
559 |
| - def __iter__(self): |
560 |
| - # Cache some things for performance reasons outside the loop. |
561 |
| - db = self.queryset.db |
562 |
| - query = self.queryset.query |
563 |
| - connection = connections[db] |
564 |
| - compiler = connection.ops.compiler("SQLCompiler")(query, connection, db) |
565 |
| - query_iterator = iter(query) |
566 |
| - |
567 |
| - try: |
568 |
| - ( |
569 |
| - model_init_names, |
570 |
| - model_init_pos, |
571 |
| - annotation_fields, |
572 |
| - ) = self.queryset.resolve_model_init_order() |
573 |
| - model_cls = self.queryset.model |
574 |
| - if model_cls._meta.pk.attname not in model_init_names: |
575 |
| - raise exceptions.FieldDoesNotExist("Raw query must include the primary key") |
576 |
| - fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns] |
577 |
| - converters = compiler.get_converters( |
578 |
| - [f.get_col(f.model._meta.db_table) if f else None for f in fields] |
579 |
| - ) |
580 |
| - if converters: |
581 |
| - query_iterator = compiler.apply_converters(query_iterator, converters) |
582 |
| - for values in query_iterator: |
583 |
| - # Associate fields to values |
584 |
| - model_init_values = [values[pos] for pos in model_init_pos] |
585 |
| - instance = model_cls.from_db(db, model_init_names, model_init_values) |
586 |
| - if annotation_fields: |
587 |
| - for column, pos in annotation_fields: |
588 |
| - setattr(instance, column, values[pos]) |
589 |
| - yield instance |
590 |
| - finally: |
591 |
| - # Done iterating the Query. If it has its own cursor, close it. |
592 |
| - if hasattr(query, "cursor") and query.cursor: |
593 |
| - query.cursor.close() |
0 commit comments