@@ -358,9 +358,7 @@ def build_query(self, columns=None):
358
358
if columns is None :
359
359
extra_fields += ordering_fields
360
360
if extra_fields :
361
- query .extra_fields = {
362
- field_name : expr .as_mql (self , self .connection ) for field_name , expr in extra_fields
363
- }
361
+ query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
364
362
where = self .get_where ()
365
363
try :
366
364
expr = where .as_mql (self , self .connection ) if where else {}
@@ -431,16 +429,18 @@ def _get_aggregate_expressions(self, expr):
431
429
elif hasattr (expr , "get_source_expressions" ):
432
430
stack .extend (expr .get_source_expressions ())
433
431
434
- def get_project_fields (self , columns = None , ordering = None ):
432
+ def get_project_fields (self , columns = None , ordering = None , force_expression = False ):
433
+ if not columns :
434
+ return {}
435
435
fields = defaultdict (dict )
436
- for name , expr in columns or [] :
436
+ for name , expr in columns + ( ordering or ()) :
437
437
collection = expr .alias if isinstance (expr , Col ) else None
438
438
try :
439
439
fields [collection ][name ] = (
440
440
1
441
441
# For brevity/simplicity, project {"field_name": 1}
442
442
# instead of {"field_name": "$field_name"}.
443
- if isinstance (expr , Col ) and name == expr .target .column
443
+ if isinstance (expr , Col ) and name == expr .target .column and not force_expression
444
444
else expr .as_mql (self , self .connection )
445
445
)
446
446
except EmptyResultSet :
@@ -451,9 +451,6 @@ def get_project_fields(self, columns=None, ordering=None):
451
451
# should appear in the top-level of the fields dict.
452
452
fields .update (fields .pop (None , {}))
453
453
fields .update (fields .pop (self .collection_name , {}))
454
- # Add order_by() fields.
455
- if fields and ordering :
456
- fields .update ({alias : expr .as_mql (self , self .connection ) for alias , expr in ordering })
457
454
# Convert defaultdict to dict so it doesn't appear as
458
455
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
459
456
return dict (fields )
@@ -466,28 +463,28 @@ def _get_ordering(self):
466
463
- A tuple of ('field_name': Expression, ...) for expressions that need
467
464
to be added to extra_fields.
468
465
"""
469
- fields = {}
466
+ fields = []
470
467
sort_ordering = SON ()
471
- extra_fields = {}
468
+ extra_fields = []
472
469
idx = itertools .count (start = 1 )
473
470
for order in self .order_by_objs or []:
474
471
if isinstance (order .expression , Col ):
475
472
field_name = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
476
- fields [ field_name ] = order .expression
473
+ fields . append (( order . expression . target . column , order .expression ))
477
474
elif isinstance (order .expression , Ref ):
478
475
field_name = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
479
476
else :
480
477
field_name = f"__order{ next (idx )} "
481
- fields [ field_name ] = order .expression
478
+ fields . append (( field_name , order .expression ))
482
479
# If the expression is ordered by NULLS FIRST or NULLS LAST,
483
480
# add a field for sorting that's 1 if null or 0 if not.
484
481
if order .nulls_first or order .nulls_last :
485
482
null_fieldname = f"__order{ next (idx )} "
486
483
condition = When (IsNull (order .expression , True ), then = Value (1 ))
487
- extra_fields [ null_fieldname ] = Case (condition , default = Value (0 ))
484
+ extra_fields . append (( null_fieldname , Case (condition , default = Value (0 )) ))
488
485
sort_ordering [null_fieldname ] = DESCENDING if order .nulls_first else ASCENDING
489
486
sort_ordering [field_name ] = DESCENDING if order .descending else ASCENDING
490
- return tuple (fields . items ()) , sort_ordering , tuple (extra_fields . items () )
487
+ return tuple (fields ) , sort_ordering , tuple (extra_fields )
491
488
492
489
def get_where (self ):
493
490
return self .where
0 commit comments