5
5
from django .db import DatabaseError , IntegrityError , NotSupportedError
6
6
from django .db .models import Count , Expression
7
7
from django .db .models .aggregates import Aggregate , Variance
8
- from django .db .models .expressions import Col , Ref , Value
8
+ from django .db .models .expressions import Case , Col , Ref , Value , When
9
9
from django .db .models .functions .comparison import Coalesce
10
10
from django .db .models .functions .math import Power
11
+ from django .db .models .lookups import IsNull
11
12
from django .db .models .sql import compiler
12
13
from django .db .models .sql .constants import GET_ITERATOR_CHUNK_SIZE , MULTI , SINGLE
13
14
from django .utils .functional import cached_property
@@ -333,13 +334,14 @@ def build_query(self, columns=None):
333
334
query = self .query_class (self )
334
335
query .aggregation_pipeline = self .get_aggregation_pipeline ()
335
336
query .lookup_pipeline = self .get_lookup_pipeline ()
336
- ordering_fields , order , need_extra_fields = self .preprocess_orderby ()
337
+ orderby_annotations , ordering_fields , order = self .preprocess_orderby ()
337
338
query .project_fields = self .get_project_fields (columns , ordering_fields )
338
339
query .ordering = order
339
- if need_extra_fields and columns is None :
340
- query .extra_fields = self .get_project_fields (
341
- ((name , field ) for name , field in ordering_fields if name .startswith ("__order" ))
342
- )
340
+
341
+ # Post pipeline fields, some of them need some refs to be compted, so we add this fields
342
+ # after the main part of the pipeline has finished.
343
+ if orderby_annotations :
344
+ query .extra_fields = self .get_project_fields (orderby_annotations , add_fields = True )
343
345
try :
344
346
where = getattr (self , "where" , self .query .where )
345
347
query .mongo_query = (
@@ -412,7 +414,7 @@ def _get_aggregate_expressions(self, expr):
412
414
def get_aggregation_pipeline (self ):
413
415
return self ._group_pipeline
414
416
415
- def get_project_fields (self , columns = None , ordering = None ):
417
+ def get_project_fields (self , columns = None , ordering = None , add_fields = False ):
416
418
fields = {}
417
419
for name , expr in columns or []:
418
420
try :
@@ -430,7 +432,7 @@ def get_project_fields(self, columns=None, ordering=None):
430
432
# another column.
431
433
fields [name ] = 1 if name == column else f"${ column } "
432
434
433
- if fields :
435
+ if fields and not add_fields :
434
436
# Add related fields.
435
437
for alias in self .query .alias_map :
436
438
if self .query .alias_refcount [alias ] and self .collection_name != alias :
@@ -446,20 +448,25 @@ def get_project_fields(self, columns=None, ordering=None):
446
448
447
449
def preprocess_orderby (self ):
448
450
fields = {}
451
+ orderby_annotations = {}
449
452
result = SON ()
450
- need_extra_fields = False
451
453
idx = itertools .count (start = 1 )
452
454
for order in self ._order_by or []:
453
- if isinstance (order .expression , Ref ):
454
- fieldname = order .expression .refs
455
- elif isinstance (order .expression , Col ):
455
+ if isinstance (order .expression , Col | Ref ):
456
456
fieldname = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
457
+ fields [fieldname ] = order .expression
457
458
else :
458
459
fieldname = f"__order{ next (idx )} "
459
- need_extra_fields = True
460
- fields [fieldname ] = order .expression
460
+ orderby_annotations [fieldname ] = order .expression
461
+
462
+ if order .nulls_first or order .nulls_last :
463
+ null_fieldname = f"__order{ next (idx )} "
464
+ condition = When (IsNull (order .expression , True ), then = Value (1 ))
465
+ orderby_annotations [null_fieldname ] = Case (condition , default = Value (0 ))
466
+ result [null_fieldname ] = DESCENDING if order .nulls_first else ASCENDING
467
+
461
468
result [fieldname ] = DESCENDING if order .descending else ASCENDING
462
- return tuple (fields .items ()), result , need_extra_fields
469
+ return tuple (orderby_annotations .items ()), tuple ( fields . items ()), result
463
470
464
471
465
472
class SQLInsertCompiler (SQLCompiler ):
0 commit comments