@@ -34,13 +34,9 @@ def __init__(self, *args, **kwargs):
34
34
# A list of OrderBy objects for this query.
35
35
self .order_by_objs = None
36
36
# Subquery parent compiler.
37
- self .parent_collections = set ()
38
37
self .column_mapping = {}
39
38
self .subqueries = []
40
39
41
- def get_parent (self ):
42
- return self .parent_compiler
43
-
44
40
def _unfold_column (self , col ):
45
41
"""
46
42
Flatten a field by returning its target or by replacing dots with
@@ -163,23 +159,42 @@ def _prepare_annotations_for_aggregation_pipeline(self, order_by):
163
159
group .update (having_group )
164
160
return group , replacements
165
161
166
- def _get_group_id_expressions (self , order_by ):
167
- """Generate group ID expressions for the aggregation pipeline."""
168
- group_expressions = set ()
169
- replacements = {}
162
+ def _get_group_expressions (self , order_by ):
163
+ # The query.group_by is either None (no GROUP BY at all), True
164
+ # (group by select fields), or a list of expressions to be added
165
+ # to the group by.
166
+ if self .query .group_by is None :
167
+ return []
168
+ seen = set ()
169
+ expressions = set ()
170
+ if self .query .group_by is not True :
171
+ for expr in self .query .group_by :
172
+ if not hasattr (expr , "as_sql" ):
173
+ expr = self .query .resolve_ref (expr )
174
+ if isinstance (expr , Ref ):
175
+ if expr .refs not in seen :
176
+ seen .add (expr .refs )
177
+ expressions .add (expr .source )
178
+ else :
179
+ expressions .add (expr )
180
+ for expr , _ , alias in self .select :
181
+ # Skip members that are already grouped.
182
+ if alias not in seen :
183
+ expressions |= set (expr .get_group_by_cols ())
170
184
if not self ._meta_ordering :
171
185
for expr , (_ , _ , is_ref ) in order_by :
186
+ # Skip references.
172
187
if not is_ref :
173
- group_expressions |= set (expr .get_group_by_cols ())
174
- for expr , * _ in self .select :
175
- group_expressions |= set (expr .get_group_by_cols ())
188
+ expressions .extend (expr .get_group_by_cols ())
176
189
having_group_by = self .having .get_group_by_cols () if self .having else ()
177
190
for expr in having_group_by :
178
- group_expressions .add (expr )
179
- if isinstance (self .query .group_by , tuple | list ):
180
- group_expressions |= set (self .query .group_by )
181
- elif self .query .group_by is None :
182
- group_expressions = set ()
191
+ expressions .add (expr )
192
+ return self .collapse_group_by (expressions , having_group_by )
193
+
194
+ def _get_group_id_expressions (self , order_by ):
195
+ """Generate group ID expressions for the aggregation pipeline."""
196
+ replacements = {}
197
+ group_expressions = self ._get_group_expressions (order_by )
183
198
if not group_expressions :
184
199
ids = None
185
200
else :
@@ -195,6 +210,8 @@ def _get_group_id_expressions(self, order_by):
195
210
ids [alias ] = Value (True ).as_mql (self , self .connection )
196
211
if replacement is not None :
197
212
replacements [col ] = replacement
213
+ if isinstance (col , Ref ):
214
+ replacements [col .source ] = replacement
198
215
return ids , replacements
199
216
200
217
def _build_aggregation_pipeline (self , ids , group ):
0 commit comments