1
1
import itertools
2
2
from collections import defaultdict
3
3
4
+ from bson import SON
4
5
from django .core .exceptions import EmptyResultSet , FullResultSet
5
6
from django .db import DatabaseError , IntegrityError , NotSupportedError
6
7
from django .db .models import Count , Expression
7
8
from django .db .models .aggregates import Aggregate , Variance
8
- from django .db .models .expressions import Col , OrderBy , Value
9
+ from django .db .models .expressions import Col , Ref , Value
9
10
from django .db .models .functions .comparison import Coalesce
10
11
from django .db .models .functions .math import Power
11
12
from django .db .models .sql import compiler
12
- from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , ORDER_DIR , SINGLE
13
+ from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , SINGLE
13
14
from django .utils .functional import cached_property
15
+ from pymongo import ASCENDING , DESCENDING
14
16
15
17
from .base import Cursor
16
18
from .query import MongoQuery , wrap_database_errors
@@ -25,6 +27,7 @@ class SQLCompiler(compiler.SQLCompiler):
25
27
def __init__ (self , * args , ** kwargs ):
26
28
super ().__init__ (* args , ** kwargs )
27
29
self .aggregation_pipeline = None
30
+ self ._order_by = None
28
31
29
32
def _get_group_alias_column (self , expr , annotation_group_idx ):
30
33
"""Generate a dummy field for use in the ids fields in $group."""
@@ -98,7 +101,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
98
101
replacements [sub_expr ] = replacing_expr
99
102
return replacements , group
100
103
101
- def _prepare_annotations_for_aggregation_pipeline (self ):
104
+ def _prepare_annotations_for_aggregation_pipeline (self , order_by ):
102
105
"""Prepare annotations for the aggregation pipeline."""
103
106
replacements = {}
104
107
group = {}
@@ -110,6 +113,13 @@ def _prepare_annotations_for_aggregation_pipeline(self):
110
113
)
111
114
replacements .update (new_replacements )
112
115
group .update (expr_group )
116
+ for expr , _ in order_by :
117
+ if expr .contains_aggregate :
118
+ new_replacements , expr_group = self ._prepare_expressions_for_pipeline (
119
+ expr , None , annotation_group_idx
120
+ )
121
+ replacements .update (new_replacements )
122
+ group .update (expr_group )
113
123
having_replacements , having_group = self ._prepare_expressions_for_pipeline (
114
124
self .having , None , annotation_group_idx
115
125
)
@@ -121,9 +131,10 @@ def _get_group_id_expressions(self, order_by):
121
131
"""Generate group ID expressions for the aggregation pipeline."""
122
132
group_expressions = set ()
123
133
replacements = {}
124
- for expr , (_ , _ , is_ref ) in order_by :
125
- if not is_ref :
126
- group_expressions |= set (expr .get_group_by_cols ())
134
+ if not self ._meta_ordering :
135
+ for expr , (_ , _ , is_ref ) in order_by :
136
+ if not is_ref :
137
+ group_expressions |= set (expr .get_group_by_cols ())
127
138
for expr , * _ in self .select :
128
139
group_expressions |= set (expr .get_group_by_cols ())
129
140
having_group_by = self .having .get_group_by_cols () if self .having else ()
@@ -187,7 +198,7 @@ def _build_aggregation_pipeline(self, ids, group):
187
198
188
199
def pre_sql_setup (self , with_col_aliases = False ):
189
200
extra_select , order_by , group_by = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
190
- group , all_replacements = self ._prepare_annotations_for_aggregation_pipeline ()
201
+ group , all_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
191
202
# query.group_by is either:
192
203
# - None: no GROUP BY
193
204
# - True: group by select fields
@@ -207,6 +218,7 @@ def pre_sql_setup(self, with_col_aliases=False):
207
218
}
208
219
)
209
220
self .aggregation_pipeline = pipeline
221
+ self ._order_by = [expr .replace_expressions (all_replacements ) for expr , _ in order_by ]
210
222
self .annotations = {
211
223
target : expr .replace_expressions (all_replacements )
212
224
for target , expr in self .query .annotation_select .items ()
@@ -333,8 +345,13 @@ def build_query(self, columns=None):
333
345
self .check_query ()
334
346
query = self .query_class (self )
335
347
query .lookup_pipeline = self .get_lookup_pipeline ()
336
- query .order_by (self ._get_ordering ())
337
- query .project_fields = self .get_project_fields (columns , ordering = query .ordering )
348
+ ordering_fields , order , need_extra_fields = self .preprocess_orderby ()
349
+ query .project_fields = self .get_project_fields (columns , ordering_fields )
350
+ query .ordering = order
351
+ if need_extra_fields and columns is None :
352
+ query .extra_fields = self .get_project_fields (
353
+ ((name , field ) for name , field in ordering_fields if name .startswith ("__order" ))
354
+ )
338
355
where = self .get_where ()
339
356
try :
340
357
expr = where .as_mql (self , self .connection ) if where else {}
@@ -380,52 +397,6 @@ def project_field(column):
380
397
+ tuple (map (project_field , related_columns ))
381
398
)
382
399
383
- def _get_ordering (self ):
384
- """
385
- Return a list of (field, ascending) tuples that the query results
386
- should be ordered by. If there is no field ordering defined, return
387
- the standard_ordering (a boolean, needed for MongoDB "$natural"
388
- ordering).
389
- """
390
- opts = self .query .get_meta ()
391
- ordering = (
392
- self .query .order_by or opts .ordering
393
- if self .query .default_ordering
394
- else self .query .order_by
395
- )
396
- if not ordering :
397
- return self .query .standard_ordering
398
- default_order , _ = ORDER_DIR ["ASC" if self .query .standard_ordering else "DESC" ]
399
- column_ordering = []
400
- columns_seen = set ()
401
- for order in ordering :
402
- if order == "?" :
403
- raise NotSupportedError ("Randomized ordering isn't supported by MongoDB." )
404
- if hasattr (order , "resolve_expression" ):
405
- # order is an expression like OrderBy, F, or database function.
406
- orderby = order if isinstance (order , OrderBy ) else order .asc ()
407
- orderby = orderby .resolve_expression (self .query , allow_joins = True , reuse = None )
408
- ascending = not orderby .descending
409
- # If the query is reversed, ascending and descending are inverted.
410
- if not self .query .standard_ordering :
411
- ascending = not ascending
412
- else :
413
- # order is a string like "field" or "field__other_field".
414
- orderby , _ = self .find_ordering_name (
415
- order , self .query .get_meta (), default_order = default_order
416
- )[0 ]
417
- ascending = not orderby .descending
418
- column = orderby .expression .as_mql (self , self .connection )
419
- if isinstance (column , dict ):
420
- raise NotSupportedError ("order_by() expression not supported." )
421
- # $sort references must not include the dollar sign.
422
- column = column .removeprefix ("$" )
423
- # Don't add the same column twice.
424
- if column not in columns_seen :
425
- columns_seen .add (column )
426
- column_ordering .append ((column , ascending ))
427
- return column_ordering
428
-
429
400
@cached_property
430
401
def collection_name (self ):
431
402
return self .query .get_meta ().db_table
@@ -473,12 +444,29 @@ def get_project_fields(self, columns=None, ordering=None):
473
444
if self .query .alias_refcount [alias ] and self .collection_name != alias :
474
445
fields [alias ] = 1
475
446
# Add order_by() fields.
476
- for column , _ in ordering or []:
477
- foreign_table = column .split ("." , 1 )[0 ] if "." in column else None
478
- if column not in fields and foreign_table not in fields :
479
- fields [column ] = 1
447
+ for alias , expression in ordering or []:
448
+ nested_entity = alias .split ("." , 1 )[0 ] if "." in alias else None
449
+ if alias not in fields and nested_entity not in fields :
450
+ fields [alias ] = expression . as_mql ( self , self . connection )
480
451
return fields
481
452
453
+ def preprocess_orderby (self ):
454
+ fields = {}
455
+ result = SON ()
456
+ need_extra_fields = False
457
+ idx = itertools .count (start = 1 )
458
+ for order in self ._order_by or []:
459
+ if isinstance (order .expression , Ref ):
460
+ fieldname = order .expression .refs
461
+ elif isinstance (order .expression , Col ):
462
+ fieldname = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
463
+ else :
464
+ fieldname = f"__order{ next (idx )} "
465
+ need_extra_fields = True
466
+ fields [fieldname ] = order .expression
467
+ result [fieldname ] = DESCENDING if order .descending else ASCENDING
468
+ return tuple (fields .items ()), result , need_extra_fields
469
+
482
470
def get_where (self ):
483
471
return self .where
484
472
0 commit comments