@@ -358,9 +358,7 @@ def build_query(self, columns=None):
358
358
if columns is None :
359
359
extra_fields += ordering_fields
360
360
if extra_fields :
361
- query .extra_fields = {
362
- field_name : expr .as_mql (self , self .connection ) for field_name , expr in extra_fields
363
- }
361
+ query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
364
362
where = self .get_where ()
365
363
try :
366
364
expr = where .as_mql (self , self .connection ) if where else {}
@@ -431,16 +429,18 @@ def _get_aggregate_expressions(self, expr):
431
429
elif hasattr (expr , "get_source_expressions" ):
432
430
stack .extend (expr .get_source_expressions ())
433
431
434
- def get_project_fields (self , columns = None , ordering = None ):
432
+ def get_project_fields (self , columns = None , ordering = None , force_expression = False ):
433
+ if not columns :
434
+ return {}
435
435
fields = defaultdict (dict )
436
- for name , expr in columns or [] :
436
+ for name , expr in columns + ( ordering or ()) :
437
437
collection = expr .alias if isinstance (expr , Col ) else None
438
438
try :
439
439
fields [collection ][name ] = (
440
440
1
441
441
# For brevity/simplicity, project {"field_name": 1}
442
442
# instead of {"field_name": "$field_name"}.
443
- if isinstance (expr , Col ) and name == expr .target .column
443
+ if isinstance (expr , Col ) and name == expr .target .column and not force_expression
444
444
else expr .as_mql (self , self .connection )
445
445
)
446
446
except EmptyResultSet :
@@ -451,9 +451,6 @@ def get_project_fields(self, columns=None, ordering=None):
451
451
# should appear in the top-level of the fields dict.
452
452
fields .update (fields .pop (None , {}))
453
453
fields .update (fields .pop (self .collection_name , {}))
454
- # Add order_by() fields.
455
- if fields and ordering :
456
- fields .update ({alias : expr .as_mql (self , self .connection ) for alias , expr in ordering })
457
454
# Convert defaultdict to dict so it doesn't appear as
458
455
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
459
456
return dict (fields )
@@ -473,7 +470,7 @@ def _get_ordering(self):
473
470
for order in self .order_by_objs or []:
474
471
if isinstance (order .expression , Col ):
475
472
field_name = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
476
- fields [field_name ] = order .expression
473
+ fields [order . expression . target . column ] = order .expression
477
474
elif isinstance (order .expression , Ref ):
478
475
field_name = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
479
476
else :
0 commit comments