@@ -44,10 +44,8 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
44
44
45
45
def _get_column_from_expression (self , expr , alias ):
46
46
"""
47
- Create a new column with the specified output type and alias to hold the aggregate value.
48
-
49
- This function generates a column target from the given expression, setting the column's
50
- output type and assigning the provided alias to the column.
47
+ Create a column named `alias` from the given expression to hold the
48
+ aggregate value.
51
49
"""
52
50
column_target = expr .output_field .__class__ ()
53
51
column_target .db_column = alias
@@ -58,17 +56,18 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
58
56
"""
59
57
Prepare expressions for the aggregation pipeline.
60
58
61
- This function handles the computation of aggregation functions used by various expressions.
62
- It separates and creates intermediate columns, and replaces nodes
63
- to simulate a group by operation.
64
-
65
- In MongoDB, the `$group` stage does not allow operations over the aggregator
66
- (e.g., `COALESCE(AVG(field), 3)`). However, it does support operations inside
67
- the aggregation (e.g., `AVG(number * 2)`).
68
- This function manages first cases by splitting the computation into stages:
69
- it computes the aggregation function first, then applies additional operations
70
- in a subsequent stage by replacing the aggregate expressions
71
- with new columns prefixed by `__aggregation`.
59
+ Handle the computation of aggregation functions used by various
60
+ expressions. Separate and create intermediate columns, and replace
61
+ nodes to simulate a group by operation.
62
+
63
+ MongoDB's $group stage doesn't allow operations over the aggregator,
64
+ e.g. COALESCE(AVG(field), 3). However, it supports operations inside
65
+ the aggregation, e.g. AVG(number * 2).
66
+
67
+ Handle the first case by splitting the computation into stages: compute
68
+ the aggregation first, then applies additional operations in a
69
+ subsequent stage by replacing the aggregate expressions with new
70
+ columns prefixed by `__aggregation`.
72
71
"""
73
72
replacements = {}
74
73
group = {}
@@ -81,6 +80,8 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
81
80
column_target .set_attributes_from_name (alias )
82
81
inner_column = Col (self .collection_name , column_target )
83
82
if sub_expr .distinct :
83
+ # If the expression should return distinct values, use
84
+ # $addToSet to deduplicate.
84
85
rhs = sub_expr .as_mql (self , self .connection , resolve_inner_expression = True )
85
86
group [alias ] = {"$addToSet" : rhs }
86
87
replacing_expr = sub_expr .copy ()
@@ -98,7 +99,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
98
99
return replacements , group
99
100
100
101
def _prepare_annotations_for_aggregation_pipeline (self ):
101
- """Prepare annotations for the MongoDB aggregation pipeline."""
102
+ """Prepare annotations for the aggregation pipeline."""
102
103
replacements = {}
103
104
group = {}
104
105
annotation_group_idx = itertools .count (start = 1 )
@@ -109,7 +110,6 @@ def _prepare_annotations_for_aggregation_pipeline(self):
109
110
)
110
111
replacements .update (new_replacements )
111
112
group .update (expr_group )
112
-
113
113
having_replacements , having_group = self ._prepare_expressions_for_pipeline (
114
114
self .having , None , annotation_group_idx
115
115
)
@@ -124,18 +124,15 @@ def _get_group_id_expressions(self, order_by):
124
124
for expr , (_ , _ , is_ref ) in order_by :
125
125
if not is_ref :
126
126
group_expressions |= set (expr .get_group_by_cols ())
127
-
128
127
for expr , * _ in self .select :
129
128
group_expressions |= set (expr .get_group_by_cols ())
130
-
131
129
having_group_by = self .having .get_group_by_cols () if self .having else ()
132
130
for expr in having_group_by :
133
131
group_expressions .add (expr )
134
132
if isinstance (self .query .group_by , tuple | list ):
135
133
group_expressions |= set (self .query .group_by )
136
134
elif self .query .group_by is None :
137
135
group_expressions = set ()
138
-
139
136
if not group_expressions :
140
137
ids = None
141
138
else :
@@ -151,7 +148,6 @@ def _get_group_id_expressions(self, order_by):
151
148
ids [alias ] = Value (True ).as_mql (self , self .connection )
152
149
if replacement is not None :
153
150
replacements [col ] = replacement
154
-
155
151
return ids , replacements
156
152
157
153
def _build_aggregation_pipeline (self , ids , group ):
@@ -180,15 +176,13 @@ def _build_aggregation_pipeline(self, ids, group):
180
176
for key in ids :
181
177
value = f"$_id.{ key } "
182
178
if self .GROUP_SEPARATOR in key :
183
- subtable , field = key .split (self .GROUP_SEPARATOR )
184
- projected_fields [subtable ][field ] = value
179
+ table , field = key .split (self .GROUP_SEPARATOR )
180
+ projected_fields [table ][field ] = value
185
181
else :
186
182
projected_fields [key ] = value
187
-
188
183
pipeline .append ({"$addFields" : projected_fields })
189
184
if "_id" not in projected_fields :
190
185
pipeline .append ({"$unset" : "_id" })
191
-
192
186
return pipeline
193
187
194
188
def pre_sql_setup (self , with_col_aliases = False ):
@@ -213,7 +207,6 @@ def pre_sql_setup(self, with_col_aliases=False):
213
207
}
214
208
)
215
209
self .aggregation_pipeline = pipeline
216
-
217
210
self .annotations = {
218
211
target : expr .replace_expressions (all_replacements )
219
212
for target , expr in self .query .annotation_select .items ()
0 commit comments