@@ -34,12 +34,19 @@ def __init__(self, *args, **kwargs):
34
34
def _get_group_alias_column (self , expr , annotation_group_idx ):
35
35
"""Generate a dummy field for use in the ids fields in $group."""
36
36
replacement = None
37
- if isinstance (expr , Col ):
38
- col = expr
37
+
38
+ # Unwrap Ref (in this part of the pipeline we can't do references over $projected fields).
39
+ expr_ = expr
40
+ while isinstance (expr_ , Ref ):
41
+ expr_ = expr_ .source
42
+ replacement = expr_
43
+
44
+ if isinstance (expr_ , Col ):
45
+ col = expr_
39
46
else :
40
47
# If the column is a composite expression, create a field for it.
41
48
alias = f"__annotation_group{ next (annotation_group_idx )} "
42
- col = self ._get_column_from_expression (expr , alias )
49
+ col = self ._get_column_from_expression (expr_ , alias )
43
50
replacement = col
44
51
if self .collection_name == col .alias :
45
52
return col .target .column , replacement
@@ -137,14 +144,21 @@ def _get_group_id_expressions(self, order_by):
137
144
for expr , (_ , _ , is_ref ) in order_by :
138
145
if not is_ref :
139
146
group_expressions |= set (expr .get_group_by_cols ())
140
- for expr , * _ in self .select :
141
- group_expressions |= set (expr .get_group_by_cols ())
142
147
having_group_by = self .having .get_group_by_cols () if self .having else ()
143
148
for expr in having_group_by :
144
149
group_expressions .add (expr )
150
+
151
+ refs_viewed = set ()
152
+ for expr , _ , alias in self .select :
153
+ group_expressions |= set (expr .get_group_by_cols ())
154
+ refs_viewed .add (alias )
145
155
if isinstance (self .query .group_by , tuple | list ):
146
- group_expressions |= set (self .query .group_by )
147
- elif self .query .group_by is None :
156
+ for expr in self .query .group_by :
157
+ if not isinstance (expr , Ref ) or expr .refs not in refs_viewed :
158
+ group_expressions .add (expr )
159
+ if isinstance (expr , Ref ):
160
+ refs_viewed .add (expr .refs )
161
+ if self .query .group_by is None :
148
162
group_expressions = set ()
149
163
if not group_expressions :
150
164
ids = None
@@ -428,32 +442,22 @@ def _get_aggregate_expressions(self, expr):
428
442
stack .extend (expr .get_source_expressions ())
429
443
430
444
def get_project_fields (self , columns = None , ordering = None ):
431
- fields = {}
445
+ fields = defaultdict ( dict )
432
446
for name , expr in columns or []:
447
+ collection = expr .alias if isinstance (expr , Col ) else None
433
448
try :
434
- column = expr .target .column
435
- except AttributeError :
436
- # Generate the MQL for an annotation.
437
- try :
438
- fields [name ] = expr .as_mql (self , self .connection )
439
- except EmptyResultSet :
440
- fields [name ] = Value (False ).as_mql (self , self .connection )
441
- except FullResultSet :
442
- fields [name ] = Value (True ).as_mql (self , self .connection )
443
- else :
444
- # If name != column, then this is an annotatation referencing
445
- # another column.
446
- fields [name ] = 1 if name == column else f"${ column } "
447
- if fields :
448
- # Add related fields.
449
- for alias in self .query .alias_map :
450
- if self .query .alias_refcount [alias ] and self .collection_name != alias :
451
- fields [alias ] = 1
449
+ fields [collection ][name ] = expr .as_mql (self , self .connection )
450
+ except EmptyResultSet :
451
+ fields [collection ][name ] = Value (False ).as_mql (self , self .connection )
452
+ except FullResultSet :
453
+ fields [collection ][name ] = Value (True ).as_mql (self , self .connection )
454
+ # Unwrap annotations.
455
+ fields .update (fields .pop (None , {}))
456
+ # Unwrap main collection's fields.
457
+ fields .update (fields .pop (self .collection_name , {}))
458
+ if fields and ordering :
452
459
# Add order_by() fields.
453
- for alias , expression in ordering or []:
454
- nested_entity = alias .split ("." , 1 )[0 ] if "." in alias else None
455
- if alias not in fields and nested_entity not in fields :
456
- fields [alias ] = expression .as_mql (self , self .connection )
460
+ fields .update ({alias : expr .as_mql (self , self .connection ) for alias , expr in ordering })
457
461
return fields
458
462
459
463
def _get_ordering (self ):
0 commit comments