1
1
import itertools
2
2
from collections import defaultdict
3
3
4
+ from bson import SON
4
5
from django .core .exceptions import EmptyResultSet , FullResultSet
5
6
from django .db import DatabaseError , IntegrityError , NotSupportedError
6
7
from django .db .models import Count , Expression
7
8
from django .db .models .aggregates import Aggregate , Variance
8
- from django .db .models .expressions import Col , OrderBy , Value
9
+ from django .db .models .expressions import Col , Ref , Value
9
10
from django .db .models .functions .comparison import Coalesce
10
11
from django .db .models .functions .math import Power
11
12
from django .db .models .sql import compiler
12
- from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , ORDER_DIR , SINGLE
13
+ from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , SINGLE
13
14
from django .utils .functional import cached_property
15
+ from pymongo import ASCENDING , DESCENDING
14
16
15
17
from .base import Cursor
16
18
from .query import MongoQuery , wrap_database_errors
@@ -25,6 +27,8 @@ class SQLCompiler(compiler.SQLCompiler):
25
27
def __init__ (self , * args , ** kwargs ):
26
28
super ().__init__ (* args , ** kwargs )
27
29
self .aggregation_pipeline = None
30
+ # A list of OrderBy objects for this query.
31
+ self .order_by_expressions = None
28
32
29
33
def _get_group_alias_column (self , expr , annotation_group_idx ):
30
34
"""Generate a dummy field for use in the ids fields in $group."""
@@ -98,7 +102,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
98
102
replacements [sub_expr ] = replacing_expr
99
103
return replacements , group
100
104
101
- def _prepare_annotations_for_aggregation_pipeline (self ):
105
+ def _prepare_annotations_for_aggregation_pipeline (self , order_by ):
102
106
"""Prepare annotations for the aggregation pipeline."""
103
107
replacements = {}
104
108
group = {}
@@ -110,6 +114,13 @@ def _prepare_annotations_for_aggregation_pipeline(self):
110
114
)
111
115
replacements .update (new_replacements )
112
116
group .update (expr_group )
117
+ for expr , _ in order_by :
118
+ if expr .contains_aggregate :
119
+ new_replacements , expr_group = self ._prepare_expressions_for_pipeline (
120
+ expr , None , annotation_group_idx
121
+ )
122
+ replacements .update (new_replacements )
123
+ group .update (expr_group )
113
124
having_replacements , having_group = self ._prepare_expressions_for_pipeline (
114
125
self .having , None , annotation_group_idx
115
126
)
@@ -121,9 +132,10 @@ def _get_group_id_expressions(self, order_by):
121
132
"""Generate group ID expressions for the aggregation pipeline."""
122
133
group_expressions = set ()
123
134
replacements = {}
124
- for expr , (_ , _ , is_ref ) in order_by :
125
- if not is_ref :
126
- group_expressions |= set (expr .get_group_by_cols ())
135
+ if not self ._meta_ordering :
136
+ for expr , (_ , _ , is_ref ) in order_by :
137
+ if not is_ref :
138
+ group_expressions |= set (expr .get_group_by_cols ())
127
139
for expr , * _ in self .select :
128
140
group_expressions |= set (expr .get_group_by_cols ())
129
141
having_group_by = self .having .get_group_by_cols () if self .having else ()
@@ -187,7 +199,7 @@ def _build_aggregation_pipeline(self, ids, group):
187
199
188
200
def pre_sql_setup (self , with_col_aliases = False ):
189
201
extra_select , order_by , group_by = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
190
- group , all_replacements = self ._prepare_annotations_for_aggregation_pipeline ()
202
+ group , all_replacements = self ._prepare_annotations_for_aggregation_pipeline (order_by )
191
203
# query.group_by is either:
192
204
# - None: no GROUP BY
193
205
# - True: group by select fields
@@ -211,6 +223,9 @@ def pre_sql_setup(self, with_col_aliases=False):
211
223
target : expr .replace_expressions (all_replacements )
212
224
for target , expr in self .query .annotation_select .items ()
213
225
}
226
+ self .order_by_expressions = [
227
+ expr .replace_expressions (all_replacements ) for expr , _ in order_by
228
+ ]
214
229
return extra_select , order_by , group_by
215
230
216
231
def execute_sql (
@@ -333,8 +348,11 @@ def build_query(self, columns=None):
333
348
self .check_query ()
334
349
query = self .query_class (self )
335
350
query .lookup_pipeline = self .get_lookup_pipeline ()
336
- query .order_by (self ._get_ordering ())
337
- query .project_fields = self .get_project_fields (columns , ordering = query .ordering )
351
+ ordering_fields , sort_ordering , extra_fields = self ._get_ordering ()
352
+ query .project_fields = self .get_project_fields (columns , ordering_fields )
353
+ query .ordering = sort_ordering
354
+ if extra_fields and columns is None :
355
+ query .extra_fields = self .get_project_fields (extra_fields )
338
356
where = self .get_where ()
339
357
try :
340
358
expr = where .as_mql (self , self .connection ) if where else {}
@@ -380,52 +398,6 @@ def project_field(column):
380
398
+ tuple (map (project_field , related_columns ))
381
399
)
382
400
383
- def _get_ordering (self ):
384
- """
385
- Return a list of (field, ascending) tuples that the query results
386
- should be ordered by. If there is no field ordering defined, return
387
- the standard_ordering (a boolean, needed for MongoDB "$natural"
388
- ordering).
389
- """
390
- opts = self .query .get_meta ()
391
- ordering = (
392
- self .query .order_by or opts .ordering
393
- if self .query .default_ordering
394
- else self .query .order_by
395
- )
396
- if not ordering :
397
- return self .query .standard_ordering
398
- default_order , _ = ORDER_DIR ["ASC" if self .query .standard_ordering else "DESC" ]
399
- column_ordering = []
400
- columns_seen = set ()
401
- for order in ordering :
402
- if order == "?" :
403
- raise NotSupportedError ("Randomized ordering isn't supported by MongoDB." )
404
- if hasattr (order , "resolve_expression" ):
405
- # order is an expression like OrderBy, F, or database function.
406
- orderby = order if isinstance (order , OrderBy ) else order .asc ()
407
- orderby = orderby .resolve_expression (self .query , allow_joins = True , reuse = None )
408
- ascending = not orderby .descending
409
- # If the query is reversed, ascending and descending are inverted.
410
- if not self .query .standard_ordering :
411
- ascending = not ascending
412
- else :
413
- # order is a string like "field" or "field__other_field".
414
- orderby , _ = self .find_ordering_name (
415
- order , self .query .get_meta (), default_order = default_order
416
- )[0 ]
417
- ascending = not orderby .descending
418
- column = orderby .expression .as_mql (self , self .connection )
419
- if isinstance (column , dict ):
420
- raise NotSupportedError ("order_by() expression not supported." )
421
- # $sort references must not include the dollar sign.
422
- column = column .removeprefix ("$" )
423
- # Don't add the same column twice.
424
- if column not in columns_seen :
425
- columns_seen .add (column )
426
- column_ordering .append ((column , ascending ))
427
- return column_ordering
428
-
429
401
@cached_property
430
402
def collection_name (self ):
431
403
return self .query .get_meta ().db_table
@@ -473,12 +445,37 @@ def get_project_fields(self, columns=None, ordering=None):
473
445
if self .query .alias_refcount [alias ] and self .collection_name != alias :
474
446
fields [alias ] = 1
475
447
# Add order_by() fields.
476
- for column , _ in ordering or []:
477
- foreign_table = column .split ("." , 1 )[0 ] if "." in column else None
478
- if column not in fields and foreign_table not in fields :
479
- fields [column ] = 1
448
+ for alias , expression in ordering or []:
449
+ nested_entity = alias .split ("." , 1 )[0 ] if "." in alias else None
450
+ if alias not in fields and nested_entity not in fields :
451
+ fields [alias ] = expression . as_mql ( self , self . connection )
480
452
return fields
481
453
454
+ def _get_ordering (self ):
455
+ """
456
+ Process the query's OrderBy objects and return:
457
+ - A tuple of ('field_name': Col/Expression, ...)
458
+ - A bson.SON mapping to pass to $sort.
459
+ - A tuple of ('field_name': Expression, ...) for expressions that need
460
+ to be added to extra_fields.
461
+ """
462
+ fields = {}
463
+ sort_ordering = SON ()
464
+ extra_fields = {}
465
+ idx = itertools .count (start = 1 )
466
+ for order in self .order_by_expressions or []:
467
+ if isinstance (order .expression , Ref ):
468
+ field_name = order .expression .refs
469
+ elif isinstance (order .expression , Col ):
470
+ field_name = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
471
+ else :
472
+ # The expression must be added to extra_fields with an alias.
473
+ field_name = f"__order{ next (idx )} "
474
+ extra_fields [field_name ] = order .expression
475
+ fields [field_name ] = order .expression
476
+ sort_ordering [field_name ] = DESCENDING if order .descending else ASCENDING
477
+ return tuple (fields .items ()), sort_ordering , tuple (extra_fields .items ())
478
+
482
479
def get_where (self ):
483
480
return self .where
484
481
0 commit comments