Skip to content

Commit b33bc80

Browse files
committed
Fix group by expressions: remove repeated expressions.
1 parent adc46aa commit b33bc80

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

django_mongodb/compiler.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,9 @@ def __init__(self, *args, **kwargs):
3434
# A list of OrderBy objects for this query.
3535
self.order_by_objs = None
3636
# Subquery parent compiler.
37-
self.parent_collections = set()
3837
self.column_mapping = {}
3938
self.subqueries = []
4039

41-
def get_parent(self):
42-
return self.parent_compiler
43-
4440
def _unfold_column(self, col):
4541
"""
4642
Flatten a field by returning its target or by replacing dots with
@@ -163,23 +159,42 @@ def _prepare_annotations_for_aggregation_pipeline(self, order_by):
163159
group.update(having_group)
164160
return group, replacements
165161

166-
def _get_group_id_expressions(self, order_by):
167-
"""Generate group ID expressions for the aggregation pipeline."""
168-
group_expressions = set()
169-
replacements = {}
162+
def _get_group_expressions(self, order_by):
163+
# The query.group_by is either None (no GROUP BY at all), True
164+
# (group by select fields), or a list of expressions to be added
165+
# to the group by.
166+
if self.query.group_by is None:
167+
return []
168+
seen = set()
169+
expressions = set()
170+
if self.query.group_by is not True:
171+
for expr in self.query.group_by:
172+
if not hasattr(expr, "as_sql"):
173+
expr = self.query.resolve_ref(expr)
174+
if isinstance(expr, Ref):
175+
if expr.refs not in seen:
176+
seen.add(expr.refs)
177+
expressions.add(expr.source)
178+
else:
179+
expressions.add(expr)
180+
for expr, _, alias in self.select:
181+
# Skip members that are already grouped.
182+
if alias not in seen:
183+
expressions |= set(expr.get_group_by_cols())
170184
if not self._meta_ordering:
171185
for expr, (_, _, is_ref) in order_by:
186+
# Skip references.
172187
if not is_ref:
173-
group_expressions |= set(expr.get_group_by_cols())
174-
for expr, *_ in self.select:
175-
group_expressions |= set(expr.get_group_by_cols())
188+
expressions.extend(expr.get_group_by_cols())
176189
having_group_by = self.having.get_group_by_cols() if self.having else ()
177190
for expr in having_group_by:
178-
group_expressions.add(expr)
179-
if isinstance(self.query.group_by, tuple | list):
180-
group_expressions |= set(self.query.group_by)
181-
elif self.query.group_by is None:
182-
group_expressions = set()
191+
expressions.add(expr)
192+
return self.collapse_group_by(expressions, having_group_by)
193+
194+
def _get_group_id_expressions(self, order_by):
195+
"""Generate group ID expressions for the aggregation pipeline."""
196+
replacements = {}
197+
group_expressions = self._get_group_expressions(order_by)
183198
if not group_expressions:
184199
ids = None
185200
else:
@@ -195,6 +210,8 @@ def _get_group_id_expressions(self, order_by):
195210
ids[alias] = Value(True).as_mql(self, self.connection)
196211
if replacement is not None:
197212
replacements[col] = replacement
213+
if isinstance(col, Ref):
214+
replacements[col.source] = replacement
198215
return ids, replacements
199216

200217
def _build_aggregation_pipeline(self, ids, group):

0 commit comments

Comments
 (0)