Skip to content

Commit 0b040c0

Browse files
Make old aggregates use new block method
1 parent bf3c6bb commit 0b040c0

File tree

9 files changed

+85
-113
lines changed

9 files changed

+85
-113
lines changed

bigframes/core/array_value.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,26 @@ def compute_general_reduction(
323323
resulting aggregate columns.
324324
"""
325325
plan = self.node
326+
327+
# short-circuit to keep things simple if all aggs are simple
328+
# TODO: Fully unify paths once rewriters are strong enough to simplify complexity from full path
329+
def _is_direct_agg(agg_expr):
330+
return isinstance(agg_expr, agg_expressions.Aggregation) and all(
331+
isinstance(child, (ex.DerefOp, ex.ScalarConstantExpression))
332+
for child in agg_expr.children
333+
)
334+
335+
if all(_is_direct_agg(agg) for agg in assignments):
336+
agg_defs = tuple((agg, ids.ColumnId.unique()) for agg in assignments)
337+
return ArrayValue(
338+
nodes.AggregateNode(
339+
child=self.node,
340+
aggregations=agg_defs, # type: ignore
341+
by_column_ids=tuple(map(ex.deref, by_column_ids)),
342+
dropna=dropna,
343+
)
344+
)
345+
326346
if dropna:
327347
for col_id in by_column_ids:
328348
plan = nodes.FilterNode(plan, ops.notnull_op.as_expr(col_id))

bigframes/core/block_transforms.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@ def quantile(
129129
window_spec=window,
130130
)
131131
quantile_cols.append(quantile_col)
132-
block, _ = block.aggregate(
133-
grouping_column_ids,
132+
block = block.aggregate(
134133
tuple(
135134
agg_expressions.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col))
136135
for col in quantile_cols
137136
),
137+
grouping_column_ids,
138138
column_labels=pd.Index(labels),
139139
dropna=dropna,
140140
)
@@ -358,12 +358,12 @@ def value_counts(
358358
if grouping_keys and drop_na:
359359
# only need this if grouping_keys is involved, otherwise the drop_na in the aggregation will handle it for us
360360
block = dropna(block, columns, how="any")
361-
block, agg_ids = block.aggregate(
362-
by_column_ids=(*grouping_keys, *columns),
361+
block = block.aggregate(
363362
aggregations=[agg_expressions.NullaryAggregation(agg_ops.size_op)],
363+
by_column_ids=(*grouping_keys, *columns),
364364
dropna=drop_na and not grouping_keys,
365365
)
366-
count_id = agg_ids[0]
366+
count_id = block.value_columns[0]
367367
if normalize:
368368
unbound_window = windows.unbound(grouping_keys=tuple(grouping_keys))
369369
block, total_count_id = block.apply_window_op(
@@ -641,7 +641,7 @@ def skew(
641641
skew_expr = _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
642642
aggregations.append(skew_expr)
643643

644-
block, _ = block.reduce_general(
644+
block = block.aggregate(
645645
aggregations, grouping_column_ids, column_labels=column_labels
646646
)
647647
if not grouping_column_ids:
@@ -674,7 +674,7 @@ def kurt(
674674
kurt_expr = _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
675675
kurt_exprs.append(kurt_expr)
676676

677-
block, _ = block.reduce_general(
677+
block = block.aggregate(
678678
kurt_exprs, grouping_column_ids, column_labels=column_labels
679679
)
680680
if not grouping_column_ids:

bigframes/core/blocks.py

Lines changed: 25 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,45 +1169,52 @@ def project_block_exprs(
11691169
index_labels=self._index_labels,
11701170
)
11711171

1172-
def reduce_general(
1172+
def aggregate(
11731173
self,
11741174
aggregations: typing.Sequence[ex.Expression] = (),
11751175
by_column_ids: typing.Sequence[str] = (),
11761176
column_labels: Optional[pd.Index] = None,
11771177
*,
11781178
dropna: bool = True,
1179-
) -> typing.Tuple[Block, typing.Sequence[str]]:
1179+
) -> Block:
11801180
"""
1181-
Version of the aggregate that supports mixing analytic and scalar expressions.
1181+
Apply aggregations to the block.
1182+
1183+
Grouping columns will form the index of the result block.
1184+
1185+
Arguments:
1186+
aggregations: Aggregation expressions to apply
1187+
by_column_ids: column ids of the aggregation keys; these are preserved through the transform and used as index.
1188+
dropna: whether null keys should be dropped
1189+
1190+
Returns:
1191+
Block
11821192
"""
11831193
if column_labels is None:
11841194
column_labels = pd.Index(range(len(aggregations)))
11851195

1186-
result_expr, output_col_ids = self.expr.compute_general_reduction(
1196+
result_expr = self.expr.compute_general_reduction(
11871197
aggregations, by_column_ids, dropna=dropna
11881198
)
11891199

1190-
names: typing.List[Label] = []
1200+
grouping_col_labels: typing.List[Label] = []
11911201
if len(by_column_ids) == 0:
11921202
result_expr, label_id = result_expr.create_constant(0, pd.Int64Dtype())
11931203
index_columns = (label_id,)
1194-
names = [None]
1204+
grouping_col_labels = [None]
11951205
else:
11961206
index_columns = tuple(by_column_ids) # type: ignore
11971207
for by_col_id in by_column_ids:
11981208
if by_col_id in self.value_columns:
1199-
names.append(self.col_id_to_label[by_col_id])
1209+
grouping_col_labels.append(self.col_id_to_label[by_col_id])
12001210
else:
1201-
names.append(self.col_id_to_index_name[by_col_id])
1211+
grouping_col_labels.append(self.col_id_to_index_name[by_col_id])
12021212

1203-
return (
1204-
Block(
1205-
result_expr,
1206-
index_columns=index_columns,
1207-
column_labels=column_labels,
1208-
index_labels=names,
1209-
),
1210-
[id.name for id in output_col_ids],
1213+
return Block(
1214+
result_expr,
1215+
index_columns=index_columns,
1216+
column_labels=column_labels,
1217+
index_labels=grouping_col_labels,
12111218
)
12121219

12131220
def apply_window_op(
@@ -1419,63 +1426,6 @@ def remap_f(x):
14191426
col_labels.append(remap_f(col_label))
14201427
return self.with_column_labels(col_labels)
14211428

1422-
def aggregate(
1423-
self,
1424-
by_column_ids: typing.Sequence[str] = (),
1425-
aggregations: typing.Sequence[agg_expressions.Aggregation] = (),
1426-
column_labels: Optional[pd.Index] = None,
1427-
*,
1428-
dropna: bool = True,
1429-
) -> typing.Tuple[Block, typing.Sequence[str]]:
1430-
"""
1431-
Apply aggregations to the block.
1432-
1433-
Arguments:
1434-
by_column_id: column id of the aggregation key, this is preserved through the transform and used as index.
1435-
aggregations: input_column_id, operation tuples
1436-
dropna: whether null keys should be dropped
1437-
1438-
Returns:
1439-
Tuple[Block, Sequence[str]]:
1440-
The first element is the grouped block. The second is the
1441-
column IDs corresponding to each applied aggregation.
1442-
"""
1443-
if column_labels is None:
1444-
column_labels = pd.Index(range(len(aggregations)))
1445-
1446-
agg_specs = [
1447-
(
1448-
aggregation,
1449-
guid.generate_guid(),
1450-
)
1451-
for aggregation in aggregations
1452-
]
1453-
output_col_ids = [agg_spec[1] for agg_spec in agg_specs]
1454-
result_expr = self.expr.aggregate(agg_specs, by_column_ids, dropna=dropna)
1455-
1456-
names: typing.List[Label] = []
1457-
if len(by_column_ids) == 0:
1458-
result_expr, label_id = result_expr.create_constant(0, pd.Int64Dtype())
1459-
index_columns = (label_id,)
1460-
names = [None]
1461-
else:
1462-
index_columns = tuple(by_column_ids) # type: ignore
1463-
for by_col_id in by_column_ids:
1464-
if by_col_id in self.value_columns:
1465-
names.append(self.col_id_to_label[by_col_id])
1466-
else:
1467-
names.append(self.col_id_to_index_name[by_col_id])
1468-
1469-
return (
1470-
Block(
1471-
result_expr,
1472-
index_columns=index_columns,
1473-
column_labels=column_labels,
1474-
index_labels=names,
1475-
),
1476-
output_col_ids,
1477-
)
1478-
14791429
def get_stat(
14801430
self,
14811431
column_id: str,
@@ -1835,7 +1785,7 @@ def pivot(
18351785
agg_expressions.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id))
18361786
for col_id in column_ids
18371787
]
1838-
result_block, _ = block.aggregate(
1788+
result_block = block.aggregate(
18391789
by_column_ids=self.index_columns,
18401790
aggregations=aggregations,
18411791
dropna=True,
@@ -2289,7 +2239,7 @@ def _get_unique_values(
22892239
self.select_columns(columns), columns
22902240
)
22912241
else:
2292-
unique_value_block, _ = self.aggregate(by_column_ids=columns, dropna=False)
2242+
unique_value_block = self.aggregate(by_column_ids=columns, dropna=False)
22932243
col_labels = self._get_labels_for_columns(columns)
22942244
unique_value_block = unique_value_block.reset_index(
22952245
drop=False

bigframes/core/groupby/dataframe_group_by.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def corr(
304304
uniq_orig_columns = utils.combine_indices(labels, pd.Index(range(len(labels))))
305305
result_labels = utils.cross_indices(uniq_orig_columns, uniq_orig_columns)
306306

307-
block, _ = block.aggregate(
307+
block = block.aggregate(
308308
by_column_ids=self._by_col_ids,
309309
aggregations=aggregations,
310310
column_labels=result_labels,
@@ -339,7 +339,7 @@ def cov(
339339
uniq_orig_columns = utils.combine_indices(labels, pd.Index(range(len(labels))))
340340
result_labels = utils.cross_indices(uniq_orig_columns, uniq_orig_columns)
341341

342-
block, _ = block.aggregate(
342+
block = block.aggregate(
343343
by_column_ids=self._by_col_ids,
344344
aggregations=aggregations,
345345
column_labels=result_labels,
@@ -383,9 +383,9 @@ def first(self, numeric_only: bool = False, min_count: int = -1) -> df.DataFrame
383383
agg_ops.FirstNonNullOp(),
384384
window_spec=window_spec,
385385
)
386-
block, _ = block.aggregate(
387-
self._by_col_ids,
388-
tuple(
386+
block = block.aggregate(
387+
by_column_ids=self._by_col_ids,
388+
aggregations=tuple(
389389
aggs.agg(firsts_id, agg_ops.AnyValueOp()) for firsts_id in firsts_ids
390390
),
391391
dropna=self._dropna,
@@ -405,9 +405,11 @@ def last(self, numeric_only: bool = False, min_count: int = -1) -> df.DataFrame:
405405
agg_ops.LastNonNullOp(),
406406
window_spec=window_spec,
407407
)
408-
block, _ = block.aggregate(
409-
self._by_col_ids,
410-
tuple(aggs.agg(lasts_id, agg_ops.AnyValueOp()) for lasts_id in lasts_ids),
408+
block = block.aggregate(
409+
by_column_ids=self._by_col_ids,
410+
aggregations=tuple(
411+
aggs.agg(lasts_id, agg_ops.AnyValueOp()) for lasts_id in lasts_ids
412+
),
411413
dropna=self._dropna,
412414
column_labels=index,
413415
)
@@ -582,7 +584,7 @@ def _agg_func(self, func) -> df.DataFrame:
582584
aggregations = [
583585
aggs.agg(col_id, agg_ops.lookup_agg_func(func)[0]) for col_id in ids
584586
]
585-
agg_block, _ = self._block.aggregate(
587+
agg_block = self._block.aggregate(
586588
by_column_ids=self._by_col_ids,
587589
aggregations=aggregations,
588590
dropna=self._dropna,
@@ -608,7 +610,7 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
608610
aggregations.append(aggs.agg(col_id, f_op))
609611
column_labels.append(label)
610612
function_labels.append(f_label)
611-
agg_block, _ = self._block.aggregate(
613+
agg_block = self._block.aggregate(
612614
by_column_ids=self._by_col_ids,
613615
aggregations=aggregations,
614616
dropna=self._dropna,
@@ -646,7 +648,7 @@ def _agg_list(self, func: typing.Sequence) -> df.DataFrame:
646648
(label, agg_ops.lookup_agg_func(f)[1]) for label in labels for f in func
647649
]
648650

649-
agg_block, _ = self._block.aggregate(
651+
agg_block = self._block.aggregate(
650652
by_column_ids=self._by_col_ids,
651653
aggregations=aggregations,
652654
dropna=self._dropna,
@@ -672,7 +674,7 @@ def _agg_named(self, **kwargs) -> df.DataFrame:
672674
col_id = self._resolve_label(v[0])
673675
aggregations.append(aggs.agg(col_id, agg_ops.lookup_agg_func(v[1])[0]))
674676
column_labels.append(k)
675-
agg_block, _ = self._block.aggregate(
677+
agg_block = self._block.aggregate(
676678
by_column_ids=self._by_col_ids,
677679
aggregations=aggregations,
678680
dropna=self._dropna,
@@ -729,7 +731,7 @@ def _aggregate_all(
729731
) -> df.DataFrame:
730732
aggregated_col_ids, labels = self._aggregated_columns(numeric_only=numeric_only)
731733
aggregations = [aggs.agg(col_id, aggregate_op) for col_id in aggregated_col_ids]
732-
result_block, _ = self._block.aggregate(
734+
result_block = self._block.aggregate(
733735
by_column_ids=self._by_col_ids,
734736
aggregations=aggregations,
735737
column_labels=labels,

bigframes/core/groupby/group_by.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def block_groupby_iter(
5555
# are more efficient.
5656
session_aware=False,
5757
)
58-
keys_block, _ = block.aggregate(by_col_ids, dropna=dropna)
58+
keys_block = block.aggregate(by_column_ids=by_col_ids, dropna=dropna)
5959
for chunk in keys_block.to_pandas_batches():
6060
# Convert to MultiIndex to make sure we get tuples,
6161
# even for singular keys.

bigframes/core/groupby/series_group_by.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,9 @@ def first(self, numeric_only: bool = False, min_count: int = -1) -> series.Serie
222222
agg_ops.FirstNonNullOp(),
223223
window_spec=window_spec,
224224
)
225-
block, _ = block.aggregate(
226-
self._by_col_ids,
225+
block = block.aggregate(
227226
(aggs.agg(firsts_id, agg_ops.AnyValueOp()),),
227+
self._by_col_ids,
228228
dropna=self._dropna,
229229
)
230230
return series.Series(block.with_column_labels([self._value_name]))
@@ -246,9 +246,9 @@ def last(self, numeric_only: bool = False, min_count: int = -1) -> series.Series
246246
agg_ops.LastNonNullOp(),
247247
window_spec=window_spec,
248248
)
249-
block, _ = block.aggregate(
250-
self._by_col_ids,
249+
block = block.aggregate(
251250
(aggs.agg(firsts_id, agg_ops.AnyValueOp()),),
251+
self._by_col_ids,
252252
dropna=self._dropna,
253253
)
254254
return series.Series(block.with_column_labels([self._value_name]))
@@ -270,7 +270,7 @@ def agg(self, func=None) -> typing.Union[df.DataFrame, series.Series]:
270270
]
271271
column_names = [agg_ops.lookup_agg_func(f)[1] for f in func]
272272

273-
agg_block, _ = self._block.aggregate(
273+
agg_block = self._block.aggregate(
274274
by_column_ids=self._by_col_ids,
275275
aggregations=aggregations,
276276
dropna=self._dropna,
@@ -413,9 +413,9 @@ def expanding(self, min_periods: int = 1) -> windows.Window:
413413
)
414414

415415
def _aggregate(self, aggregate_op: agg_ops.UnaryAggregateOp) -> series.Series:
416-
result_block, _ = self._block.aggregate(
417-
self._by_col_ids,
416+
result_block = self._block.aggregate(
418417
(aggs.agg(self._value_column, aggregate_op),),
418+
self._by_col_ids,
419419
dropna=self._dropna,
420420
)
421421

0 commit comments

Comments (0)