1
1
import itertools
2
2
3
+ from bson import SON
3
4
from django .core .exceptions import EmptyResultSet , FullResultSet
4
5
from django .db import DatabaseError , IntegrityError , NotSupportedError
5
6
from django .db .models import Count , Expression
6
7
from django .db .models .aggregates import Aggregate , Variance
7
- from django .db .models .expressions import Col , OrderBy , Value
8
+ from django .db .models .expressions import Col , Ref , Value
8
9
from django .db .models .functions .comparison import Coalesce
9
10
from django .db .models .functions .math import Power
10
11
from django .db .models .sql import compiler
11
- from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , ORDER_DIR , SINGLE
12
+ from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , SINGLE
12
13
from django .utils .functional import cached_property
14
+ from pymongo import ASCENDING , DESCENDING
13
15
14
16
from .base import Cursor
15
17
from .query import MongoQuery , wrap_database_errors
@@ -24,6 +26,7 @@ class SQLCompiler(compiler.SQLCompiler):
24
26
def __init__ (self , * args , ** kwargs ):
25
27
super ().__init__ (* args , ** kwargs )
26
28
self ._group_pipeline = None
29
+ self ._order_by = None
27
30
28
31
def _get_group_alias_column (self , col , annotation_group_idx ):
29
32
"""Generate alias and replacement for group columns."""
@@ -71,7 +74,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, count):
71
74
replacements [sub_expr ] = replacing_expr
72
75
return replacements , group
73
76
74
- def _prepare_annotations_for_group_pipeline (self ):
77
+ def _prepare_annotations_for_group_pipeline (self , order_by ):
75
78
"""Prepare annotations for the MongoDB aggregation pipeline."""
76
79
replacements = {}
77
80
group = {}
@@ -84,6 +87,14 @@ def _prepare_annotations_for_group_pipeline(self):
84
87
replacements .update (new_replacements )
85
88
group .update (expr_group )
86
89
90
+ for expr , _ in order_by :
91
+ if expr .contains_aggregate :
92
+ new_replacements , expr_group = self ._prepare_expressions_for_pipeline (
93
+ expr , None , count
94
+ )
95
+ replacements .update (new_replacements )
96
+ group .update (expr_group )
97
+
87
98
having_replacements , having_group = self ._prepare_expressions_for_pipeline (
88
99
self .having , None , count
89
100
)
@@ -95,9 +106,10 @@ def _get_group_id_expressions(self, order_by):
95
106
"""Generate group ID expressions for the aggregation pipeline."""
96
107
group_expressions = set ()
97
108
replacements = {}
98
- for expr , (_ , _ , is_ref ) in order_by :
99
- if not is_ref :
100
- group_expressions |= set (expr .get_group_by_cols ())
109
+ if not self ._meta_ordering :
110
+ for expr , (_ , _ , is_ref ) in order_by :
111
+ if not is_ref :
112
+ group_expressions |= set (expr .get_group_by_cols ())
101
113
102
114
for expr , * _ in self .select :
103
115
group_expressions |= set (expr .get_group_by_cols ())
@@ -169,7 +181,7 @@ def _build_group_pipeline(self, ids, group):
169
181
170
182
def pre_sql_setup (self , with_col_aliases = False ):
171
183
extra_select , order_by , group_by = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
172
- group , all_replacements = self ._prepare_annotations_for_group_pipeline ()
184
+ group , all_replacements = self ._prepare_annotations_for_group_pipeline (order_by )
173
185
174
186
# The query.group_by is either None (no GROUP BY at all), True
175
187
# (group by select fields), or a list of expressions to be added
@@ -190,6 +202,7 @@ def pre_sql_setup(self, with_col_aliases=False):
190
202
)
191
203
self ._group_pipeline = pipeline
192
204
205
+ self ._order_by = [expr .replace_expressions (all_replacements ) for expr , _ in order_by ]
193
206
self .annotations = {
194
207
target : expr .replace_expressions (all_replacements )
195
208
for target , expr in self .query .annotation_select .items ()
@@ -320,8 +333,13 @@ def build_query(self, columns=None):
320
333
query = self .query_class (self )
321
334
query .aggregation_pipeline = self .get_aggregation_pipeline ()
322
335
query .lookup_pipeline = self .get_lookup_pipeline ()
323
- query .order_by (self ._get_ordering ())
324
- query .project_fields = self .get_project_fields (columns , ordering = query .ordering )
336
+ ordering_fields , order , need_extra_fields = self .preprocess_orderby ()
337
+ query .project_fields = self .get_project_fields (columns , ordering_fields )
338
+ query .ordering = order
339
+ if need_extra_fields and columns is None :
340
+ query .extra_fields = self .get_project_fields (
341
+ ((name , field ) for name , field in ordering_fields if name .startswith ("__order" ))
342
+ )
325
343
try :
326
344
where = getattr (self , "where" , self .query .where )
327
345
query .mongo_query = (
@@ -367,52 +385,6 @@ def project_field(column):
367
385
+ tuple (map (project_field , related_columns ))
368
386
)
369
387
370
- def _get_ordering (self ):
371
- """
372
- Return a list of (field, ascending) tuples that the query results
373
- should be ordered by. If there is no field ordering defined, return
374
- the standard_ordering (a boolean, needed for MongoDB "$natural"
375
- ordering).
376
- """
377
- opts = self .query .get_meta ()
378
- ordering = (
379
- self .query .order_by or opts .ordering
380
- if self .query .default_ordering
381
- else self .query .order_by
382
- )
383
- if not ordering :
384
- return self .query .standard_ordering
385
- default_order , _ = ORDER_DIR ["ASC" if self .query .standard_ordering else "DESC" ]
386
- column_ordering = []
387
- columns_seen = set ()
388
- for order in ordering :
389
- if order == "?" :
390
- raise NotSupportedError ("Randomized ordering isn't supported by MongoDB." )
391
- if hasattr (order , "resolve_expression" ):
392
- # order is an expression like OrderBy, F, or database function.
393
- orderby = order if isinstance (order , OrderBy ) else order .asc ()
394
- orderby = orderby .resolve_expression (self .query , allow_joins = True , reuse = None )
395
- ascending = not orderby .descending
396
- # If the query is reversed, ascending and descending are inverted.
397
- if not self .query .standard_ordering :
398
- ascending = not ascending
399
- else :
400
- # order is a string like "field" or "field__other_field".
401
- orderby , _ = self .find_ordering_name (
402
- order , self .query .get_meta (), default_order = default_order
403
- )[0 ]
404
- ascending = not orderby .descending
405
- column = orderby .expression .as_mql (self , self .connection )
406
- if isinstance (column , dict ):
407
- raise NotSupportedError ("order_by() expression not supported." )
408
- # $sort references must not include the dollar sign.
409
- column = column .removeprefix ("$" )
410
- # Don't add the same column twice.
411
- if column not in columns_seen :
412
- columns_seen .add (column )
413
- column_ordering .append ((column , ascending ))
414
- return column_ordering
415
-
416
388
@cached_property
417
389
def collection_name (self ):
418
390
return self .query .get_meta ().db_table
@@ -464,13 +436,31 @@ def get_project_fields(self, columns=None, ordering=None):
464
436
if self .query .alias_refcount [alias ] and self .collection_name != alias :
465
437
fields [alias ] = 1
466
438
467
- for column , _ in ordering or []:
468
- foreign_table = column .split ("." , 1 )[0 ] if "." in column else None
469
- if column not in fields and foreign_table not in fields :
470
- fields [column ] = 1
439
+ # Add order_by() fields.
440
+ for alias , expression in ordering or []:
441
+ nested_entity = alias .split ("." , 1 )[0 ] if "." in alias else None
442
+ if alias not in fields and nested_entity not in fields :
443
+ fields [alias ] = expression .as_mql (self , self .connection )
471
444
472
445
return fields
473
446
447
+ def preprocess_orderby (self ):
448
+ fields = {}
449
+ result = SON ()
450
+ need_extra_fields = False
451
+ idx = itertools .count (start = 1 )
452
+ for order in self ._order_by or []:
453
+ if isinstance (order .expression , Ref ):
454
+ fieldname = order .expression .refs
455
+ elif isinstance (order .expression , Col ):
456
+ fieldname = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
457
+ else :
458
+ fieldname = f"__order{ next (idx )} "
459
+ need_extra_fields = True
460
+ fields [fieldname ] = order .expression
461
+ result [fieldname ] = DESCENDING if order .descending else ASCENDING
462
+ return tuple (fields .items ()), result , need_extra_fields
463
+
474
464
475
465
class SQLInsertCompiler (SQLCompiler ):
476
466
def execute_sql (self , returning_fields = None ):
0 commit comments