Skip to content

Commit 3b53092

Browse files
refactor: Generalize Block.aggregate to non-unary aggregates (#1304)
1 parent c8e7b8f commit 3b53092

File tree

4 files changed

+84
-71
lines changed

4 files changed

+84
-71
lines changed

bigframes/core/block_transforms.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,16 @@ def quantile(
129129
window_spec=window,
130130
)
131131
quantile_cols.append(quantile_col)
132-
block, results = block.aggregate(
132+
block, _ = block.aggregate(
133133
grouping_column_ids,
134-
tuple((col, agg_ops.AnyValueOp()) for col in quantile_cols),
134+
tuple(
135+
ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col))
136+
for col in quantile_cols
137+
),
138+
column_labels=pd.Index(labels),
135139
dropna=dropna,
136140
)
137-
return block.select_columns(results).with_column_labels(labels)
141+
return block
138142

139143

140144
def interpolate(block: blocks.Block, method: str = "linear") -> blocks.Block:
@@ -355,7 +359,7 @@ def value_counts(
355359
block, dummy = block.create_constant(1)
356360
block, agg_ids = block.aggregate(
357361
by_column_ids=columns,
358-
aggregations=[(dummy, agg_ops.count_op)],
362+
aggregations=[ex.UnaryAggregation(agg_ops.count_op, ex.deref(dummy))],
359363
dropna=dropna,
360364
)
361365
count_id = agg_ids[0]
@@ -589,9 +593,18 @@ def skew(
589593
# counts, moment3 for each column
590594
aggregations = []
591595
for i, col in enumerate(original_columns):
592-
count_agg = (col, agg_ops.count_op)
593-
moment3_agg = (delta3_ids[i], agg_ops.mean_op)
594-
variance_agg = (col, agg_ops.PopVarOp())
596+
count_agg = ex.UnaryAggregation(
597+
agg_ops.count_op,
598+
ex.deref(col),
599+
)
600+
moment3_agg = ex.UnaryAggregation(
601+
agg_ops.mean_op,
602+
ex.deref(delta3_ids[i]),
603+
)
604+
variance_agg = ex.UnaryAggregation(
605+
agg_ops.PopVarOp(),
606+
ex.deref(col),
607+
)
595608
aggregations.extend([count_agg, moment3_agg, variance_agg])
596609

597610
block, agg_ids = block.aggregate(
@@ -631,9 +644,9 @@ def kurt(
631644
# counts, moment4 for each column
632645
aggregations = []
633646
for i, col in enumerate(original_columns):
634-
count_agg = (col, agg_ops.count_op)
635-
moment4_agg = (delta4_ids[i], agg_ops.mean_op)
636-
variance_agg = (col, agg_ops.PopVarOp())
647+
count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(col))
648+
moment4_agg = ex.UnaryAggregation(agg_ops.mean_op, ex.deref(delta4_ids[i]))
649+
variance_agg = ex.UnaryAggregation(agg_ops.PopVarOp(), ex.deref(col))
637650
aggregations.extend([count_agg, moment4_agg, variance_agg])
638651

639652
block, agg_ids = block.aggregate(

bigframes/core/blocks.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,38 +1174,31 @@ def remap_f(x):
11741174
def aggregate(
11751175
self,
11761176
by_column_ids: typing.Sequence[str] = (),
1177-
aggregations: typing.Sequence[
1178-
typing.Tuple[
1179-
str, typing.Union[agg_ops.UnaryAggregateOp, agg_ops.NullaryAggregateOp]
1180-
]
1181-
] = (),
1177+
aggregations: typing.Sequence[ex.Aggregation] = (),
1178+
column_labels: Optional[pd.Index] = None,
11821179
*,
11831180
dropna: bool = True,
11841181
) -> typing.Tuple[Block, typing.Sequence[str]]:
11851182
"""
1186-
Apply aggregations to the block. Callers responsible for setting index column(s) after.
1183+
Apply aggregations to the block.
11871184
Arguments:
11881185
by_column_id: column id of the aggregation key, this is preserved through the transform and used as index.
11891186
aggregations: input_column_id, operation tuples
1190-
as_index: if True, grouping keys will be index columns in result, otherwise they will be non-index columns.
11911187
dropna: whether null keys should be dropped
11921188
"""
1189+
if column_labels is None:
1190+
column_labels = pd.Index(range(len(aggregations)))
1191+
11931192
agg_specs = [
11941193
(
1195-
ex.UnaryAggregation(operation, ex.deref(input_id))
1196-
if isinstance(operation, agg_ops.UnaryAggregateOp)
1197-
else ex.NullaryAggregation(operation),
1194+
aggregation,
11981195
guid.generate_guid(),
11991196
)
1200-
for input_id, operation in aggregations
1197+
for aggregation in aggregations
12011198
]
12021199
output_col_ids = [agg_spec[1] for agg_spec in agg_specs]
12031200
result_expr = self.expr.aggregate(agg_specs, by_column_ids, dropna=dropna)
12041201

1205-
aggregate_labels = self._get_labels_for_columns(
1206-
[agg[0] for agg in aggregations]
1207-
)
1208-
12091202
names: typing.List[Label] = []
12101203
if len(by_column_ids) == 0:
12111204
result_expr, label_id = result_expr.create_constant(0, pd.Int64Dtype())
@@ -1223,7 +1216,7 @@ def aggregate(
12231216
Block(
12241217
result_expr,
12251218
index_columns=index_columns,
1226-
column_labels=aggregate_labels,
1219+
column_labels=column_labels,
12271220
index_labels=names,
12281221
),
12291222
output_col_ids,
@@ -1561,7 +1554,10 @@ def pivot(
15611554
column_ids.append(masked_id)
15621555

15631556
block = block.select_columns(column_ids)
1564-
aggregations = [(col_id, agg_ops.AnyValueOp()) for col_id in column_ids]
1557+
aggregations = [
1558+
ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id))
1559+
for col_id in column_ids
1560+
]
15651561
result_block, _ = block.aggregate(
15661562
by_column_ids=self.index_columns,
15671563
aggregations=aggregations,

bigframes/core/groupby/__init__.py

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import typing
18-
from typing import Sequence, Union
18+
from typing import Sequence, Tuple, Union
1919

2020
import bigframes_vendored.constants as constants
2121
import bigframes_vendored.pandas.core.groupby as vendored_pandas_groupby
@@ -26,6 +26,7 @@
2626
import bigframes.core as core
2727
import bigframes.core.block_transforms as block_ops
2828
import bigframes.core.blocks as blocks
29+
import bigframes.core.expression
2930
import bigframes.core.ordering as order
3031
import bigframes.core.utils as utils
3132
import bigframes.core.validations as validations
@@ -334,24 +335,19 @@ def agg(self, func=None, **kwargs) -> typing.Union[df.DataFrame, series.Series]:
334335
return self._agg_named(**kwargs)
335336

336337
def _agg_string(self, func: str) -> df.DataFrame:
337-
aggregations = [
338-
(col_id, agg_ops.lookup_agg_func(func))
339-
for col_id in self._aggregated_columns()
340-
]
338+
ids, labels = self._aggregated_columns()
339+
aggregations = [agg(col_id, agg_ops.lookup_agg_func(func)) for col_id in ids]
341340
agg_block, _ = self._block.aggregate(
342341
by_column_ids=self._by_col_ids,
343342
aggregations=aggregations,
344343
dropna=self._dropna,
344+
column_labels=labels,
345345
)
346346
dataframe = df.DataFrame(agg_block)
347347
return dataframe if self._as_index else self._convert_index(dataframe)
348348

349349
def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
350-
aggregations: typing.List[
351-
typing.Tuple[
352-
str, typing.Union[agg_ops.UnaryAggregateOp, agg_ops.NullaryAggregateOp]
353-
]
354-
] = []
350+
aggregations: typing.List[bigframes.core.expression.Aggregation] = []
355351
column_labels = []
356352

357353
want_aggfunc_level = any(utils.is_list_like(aggs) for aggs in func.values())
@@ -362,7 +358,7 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
362358
funcs_for_id if utils.is_list_like(funcs_for_id) else [funcs_for_id]
363359
)
364360
for f in func_list:
365-
aggregations.append((col_id, agg_ops.lookup_agg_func(f)))
361+
aggregations.append(agg(col_id, agg_ops.lookup_agg_func(f)))
366362
column_labels.append(label)
367363
agg_block, _ = self._block.aggregate(
368364
by_column_ids=self._by_col_ids,
@@ -373,7 +369,10 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
373369
agg_block = agg_block.with_column_labels(
374370
utils.combine_indices(
375371
pd.Index(column_labels),
376-
pd.Index(agg[1].name for agg in aggregations),
372+
pd.Index(
373+
typing.cast(agg_ops.AggregateOp, agg.op).name
374+
for agg in aggregations
375+
),
377376
)
378377
)
379378
else:
@@ -382,34 +381,21 @@ def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
382381
return dataframe if self._as_index else self._convert_index(dataframe)
383382

384383
def _agg_list(self, func: typing.Sequence) -> df.DataFrame:
384+
ids, labels = self._aggregated_columns()
385385
aggregations = [
386-
(col_id, agg_ops.lookup_agg_func(f))
387-
for col_id in self._aggregated_columns()
388-
for f in func
386+
agg(col_id, agg_ops.lookup_agg_func(f)) for col_id in ids for f in func
389387
]
390388

391389
if self._block.column_labels.nlevels > 1:
392390
# Restructure MultiIndex for proper format: (idx1, idx2, func)
393391
# rather than ((idx1, idx2), func).
394-
aggregated_columns = pd.MultiIndex.from_tuples(
395-
[
396-
self._block.col_id_to_label[col_id]
397-
for col_id in self._aggregated_columns()
398-
],
399-
names=[*self._block.column_labels.names],
400-
).to_frame(index=False)
401-
402392
column_labels = [
403-
tuple(col_id) + (f,)
404-
for col_id in aggregated_columns.to_numpy()
405-
for f in func
406-
]
407-
else:
408-
column_labels = [
409-
(self._block.col_id_to_label[col_id], f)
410-
for col_id in self._aggregated_columns()
393+
tuple(label) + (f,)
394+
for label in labels.to_frame(index=False).to_numpy()
411395
for f in func
412396
]
397+
else: # Single-level index
398+
column_labels = [(label, f) for label in labels for f in func]
413399

414400
agg_block, _ = self._block.aggregate(
415401
by_column_ids=self._by_col_ids,
@@ -435,7 +421,7 @@ def _agg_named(self, **kwargs) -> df.DataFrame:
435421
if not isinstance(v, tuple) or (len(v) != 2):
436422
raise TypeError("kwargs values must be 2-tuples of column, aggfunc")
437423
col_id = self._resolve_label(v[0])
438-
aggregations.append((col_id, agg_ops.lookup_agg_func(v[1])))
424+
aggregations.append(agg(col_id, agg_ops.lookup_agg_func(v[1])))
439425
column_labels.append(k)
440426
agg_block, _ = self._block.aggregate(
441427
by_column_ids=self._by_col_ids,
@@ -470,15 +456,19 @@ def _raise_on_non_numeric(self, op: str):
470456
)
471457
return self
472458

473-
def _aggregated_columns(self, numeric_only: bool = False) -> typing.Sequence[str]:
459+
def _aggregated_columns(
460+
self, numeric_only: bool = False
461+
) -> Tuple[typing.Sequence[str], pd.Index]:
474462
valid_agg_cols: list[str] = []
475-
for col_id in self._selected_cols:
463+
offsets: list[int] = []
464+
for i, col_id in enumerate(self._block.value_columns):
476465
is_numeric = (
477466
self._column_type(col_id) in dtypes.NUMERIC_BIGFRAMES_TYPES_PERMISSIVE
478467
)
479-
if is_numeric or not numeric_only:
468+
if (col_id in self._selected_cols) and (is_numeric or not numeric_only):
469+
offsets.append(i)
480470
valid_agg_cols.append(col_id)
481-
return valid_agg_cols
471+
return valid_agg_cols, self._block.column_labels.take(offsets)
482472

483473
def _column_type(self, col_id: str) -> dtypes.Dtype:
484474
col_offset = self._block.value_columns.index(col_id)
@@ -488,11 +478,12 @@ def _column_type(self, col_id: str) -> dtypes.Dtype:
488478
def _aggregate_all(
489479
self, aggregate_op: agg_ops.UnaryAggregateOp, numeric_only: bool = False
490480
) -> df.DataFrame:
491-
aggregated_col_ids = self._aggregated_columns(numeric_only=numeric_only)
492-
aggregations = [(col_id, aggregate_op) for col_id in aggregated_col_ids]
481+
aggregated_col_ids, labels = self._aggregated_columns(numeric_only=numeric_only)
482+
aggregations = [agg(col_id, aggregate_op) for col_id in aggregated_col_ids]
493483
result_block, _ = self._block.aggregate(
494484
by_column_ids=self._by_col_ids,
495485
aggregations=aggregations,
486+
column_labels=labels,
496487
dropna=self._dropna,
497488
)
498489
dataframe = df.DataFrame(result_block)
@@ -508,7 +499,7 @@ def _apply_window_op(
508499
window_spec = window or window_specs.cumulative_rows(
509500
grouping_keys=tuple(self._by_col_ids)
510501
)
511-
columns = self._aggregated_columns(numeric_only=numeric_only)
502+
columns, _ = self._aggregated_columns(numeric_only=numeric_only)
512503
block, result_ids = self._block.multi_apply_window_op(
513504
columns, op, window_spec=window_spec
514505
)
@@ -639,11 +630,11 @@ def prod(self, *args) -> series.Series:
639630
def agg(self, func=None) -> typing.Union[df.DataFrame, series.Series]:
640631
column_names: list[str] = []
641632
if isinstance(func, str):
642-
aggregations = [(self._value_column, agg_ops.lookup_agg_func(func))]
633+
aggregations = [agg(self._value_column, agg_ops.lookup_agg_func(func))]
643634
column_names = [func]
644635
elif utils.is_list_like(func):
645636
aggregations = [
646-
(self._value_column, agg_ops.lookup_agg_func(f)) for f in func
637+
agg(self._value_column, agg_ops.lookup_agg_func(f)) for f in func
647638
]
648639
column_names = list(func)
649640
else:
@@ -756,7 +747,7 @@ def expanding(self, min_periods: int = 1) -> windows.Window:
756747
def _aggregate(self, aggregate_op: agg_ops.UnaryAggregateOp) -> series.Series:
757748
result_block, _ = self._block.aggregate(
758749
self._by_col_ids,
759-
((self._value_column, aggregate_op),),
750+
(agg(self._value_column, aggregate_op),),
760751
dropna=self._dropna,
761752
)
762753

@@ -781,3 +772,13 @@ def _apply_window_op(
781772
window_spec=window_spec,
782773
)
783774
return series.Series(block.select_column(result_id))
775+
776+
777+
def agg(input: str, op: agg_ops.AggregateOp) -> bigframes.core.expression.Aggregation:
778+
if isinstance(op, agg_ops.UnaryAggregateOp):
779+
return bigframes.core.expression.UnaryAggregation(
780+
op, bigframes.core.expression.deref(input)
781+
)
782+
else:
783+
assert isinstance(op, agg_ops.NullaryAggregateOp)
784+
return bigframes.core.expression.NullaryAggregation(op)

bigframes/series.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,9 @@ def mode(self) -> Series:
10621062
# Approach: Count each value, return each value for which count(x) == max(counts))
10631063
block, agg_ids = block.aggregate(
10641064
by_column_ids=[self._value_column],
1065-
aggregations=((self._value_column, agg_ops.count_op),),
1065+
aggregations=(
1066+
ex.UnaryAggregation(agg_ops.count_op, ex.deref(self._value_column)),
1067+
),
10661068
)
10671069
value_count_col_id = agg_ids[0]
10681070
block, max_value_count_col_id = block.apply_window_op(
@@ -1675,7 +1677,8 @@ def unique(self, keep_order=True) -> Series:
16751677
return self.drop_duplicates()
16761678
block, result = self._block.aggregate(
16771679
[self._value_column],
1678-
[(self._value_column, agg_ops.AnyValueOp())],
1680+
[ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(self._value_column))],
1681+
column_labels=self._block.column_labels,
16791682
dropna=False,
16801683
)
16811684
return Series(block.select_columns(result).reset_index())

0 commit comments

Comments (0)