@@ -357,9 +357,7 @@ def build_query(self, columns=None):
357
357
if columns is None :
358
358
extra_fields += ordering_fields
359
359
if extra_fields :
360
- query .extra_fields = {
361
- field_name : expr .as_mql (self , self .connection ) for field_name , expr in extra_fields
362
- }
360
+ query .extra_fields = self .get_project_fields (extra_fields , force_expression = True )
363
361
where = self .get_where ()
364
362
try :
365
363
expr = where .as_mql (self , self .connection ) if where else {}
@@ -429,16 +427,18 @@ def _get_aggregate_expressions(self, expr):
429
427
elif hasattr (expr , "get_source_expressions" ):
430
428
stack .extend (expr .get_source_expressions ())
431
429
432
- def get_project_fields (self , columns = None , ordering = None ):
430
+ def get_project_fields (self , columns = None , ordering = None , force_expression = False ):
431
+ if not columns :
432
+ return {}
433
433
fields = defaultdict (dict )
434
- for name , expr in columns or [] :
434
+ for name , expr in columns + ( ordering or ()) :
435
435
collection = expr .alias if isinstance (expr , Col ) else None
436
436
try :
437
437
fields [collection ][name ] = (
438
438
1
439
439
# For brevity/simplicity, project {"field_name": 1}
440
440
# instead of {"field_name": "$field_name"}.
441
- if isinstance (expr , Col ) and name == expr .target .column
441
+ if isinstance (expr , Col ) and name == expr .target .column and not force_expression
442
442
else expr .as_mql (self , self .connection )
443
443
)
444
444
except EmptyResultSet :
@@ -449,9 +449,6 @@ def get_project_fields(self, columns=None, ordering=None):
449
449
# should appear in the top-level of the fields dict.
450
450
fields .update (fields .pop (None , {}))
451
451
fields .update (fields .pop (self .collection_name , {}))
452
- # Add order_by() fields.
453
- if fields and ordering :
454
- fields .update ({alias : expr .as_mql (self , self .connection ) for alias , expr in ordering })
455
452
# Convert defaultdict to dict so it doesn't appear as
456
453
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
457
454
return dict (fields )
@@ -471,7 +468,7 @@ def _get_ordering(self):
471
468
for order in self .order_by_objs or []:
472
469
if isinstance (order .expression , Col ):
473
470
field_name = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
474
- fields [field_name ] = order .expression
471
+ fields [order . expression . target . column ] = order .expression
475
472
elif isinstance (order .expression , Ref ):
476
473
field_name = order .expression .as_mql (self , self .connection ).removeprefix ("$" )
477
474
else :
0 commit comments