Skip to content

Commit fbb2094

Browse files
perf: Improve axis=1 aggregation performance (#2036)
1 parent 3961637 commit fbb2094

File tree

13 files changed

+200
-44
lines changed

13 files changed

+200
-44
lines changed

bigframes/core/blocks.py

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,46 +1232,10 @@ def aggregate_all_and_stack(
12321232
index_labels=[None],
12331233
).transpose(original_row_index=pd.Index([None]), single_row_mode=True)
12341234
else: # axis_n == 1
1235-
# using offsets as identity to group on.
1236-
# TODO: Allow to promote identity/total_order columns instead for better perf
1237-
expr_with_offsets, offset_col = self.expr.promote_offsets()
1238-
stacked_expr, (_, value_col_ids, passthrough_cols,) = unpivot(
1239-
expr_with_offsets,
1240-
row_labels=self.column_labels,
1241-
unpivot_columns=[tuple(self.value_columns)],
1242-
passthrough_columns=[*self.index_columns, offset_col],
1243-
)
1244-
# these corresponed to passthrough_columns provided to unpivot
1245-
index_cols = passthrough_cols[:-1]
1246-
og_offset_col = passthrough_cols[-1]
1247-
index_aggregations = [
1248-
(
1249-
ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id)),
1250-
col_id,
1251-
)
1252-
for col_id in index_cols
1253-
]
1254-
# TODO: may need add NullaryAggregation in main_aggregation
1255-
# when agg add support for axis=1, needed for agg("size", axis=1)
1256-
assert isinstance(
1257-
operation, agg_ops.UnaryAggregateOp
1258-
), f"Expected a unary operation, but got {operation}. Please report this error and how you got here to the BigQuery DataFrames team (bit.ly/bigframes-feedback)."
1259-
main_aggregation = (
1260-
ex.UnaryAggregation(operation, ex.deref(value_col_ids[0])),
1261-
value_col_ids[0],
1262-
)
1263-
# Drop row identity after aggregating over it
1264-
result_expr = stacked_expr.aggregate(
1265-
[*index_aggregations, main_aggregation],
1266-
by_column_ids=[og_offset_col],
1267-
dropna=dropna,
1268-
).drop_columns([og_offset_col])
1269-
return Block(
1270-
result_expr,
1271-
index_columns=index_cols,
1272-
column_labels=[None],
1273-
index_labels=self.index.names,
1274-
)
1235+
as_array = ops.ToArrayOp().as_expr(*(col for col in self.value_columns))
1236+
reduced = ops.ArrayReduceOp(operation).as_expr(as_array)
1237+
block, id = self.project_expr(reduced, None)
1238+
return block.select_column(id)
12751239

12761240
def aggregate_size(
12771241
self,

bigframes/core/compile/ibis_compiler/aggregate_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def _(
165165
) -> ibis_types.NumericValue:
166166
# Will be null if all inputs are null. Pandas defaults to zero sum though.
167167
bq_sum = _apply_window_if_present(column.sum(), window)
168-
return bq_sum.fill_null(ibis_types.literal(0))
168+
return bq_sum.coalesce(ibis_types.literal(0))
169169

170170

171171
@compile_unary_agg.register

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,28 @@ def array_slice_op_impl(x: ibis_types.Value, op: ops.ArraySliceOp):
12011201
return res
12021202

12031203

1204+
@scalar_op_compiler.register_nary_op(ops.ToArrayOp, pass_op=False)
1205+
def to_array_op_impl(*values: ibis_types.Value):
1206+
do_upcast_bool = any(t.type().is_numeric() for t in values)
1207+
if do_upcast_bool:
1208+
values = tuple(
1209+
val.cast(ibis_dtypes.int64) if val.type().is_boolean() else val
1210+
for val in values
1211+
)
1212+
return ibis_api.array(values)
1213+
1214+
1215+
@scalar_op_compiler.register_unary_op(ops.ArrayReduceOp, pass_op=True)
1216+
def array_reduce_op_impl(x: ibis_types.Value, op: ops.ArrayReduceOp):
1217+
import bigframes.core.compile.ibis_compiler.aggregate_compiler as agg_compilers
1218+
1219+
return typing.cast(ibis_types.ArrayValue, x).reduce(
1220+
lambda arr_vals: agg_compilers.compile_unary_agg(
1221+
op.aggregation, typing.cast(ibis_types.Column, arr_vals)
1222+
)
1223+
)
1224+
1225+
12041226
# JSON Ops
12051227
@scalar_op_compiler.register_binary_op(ops.JSONSet, pass_op=True)
12061228
def json_set_op_impl(x: ibis_types.Value, y: ibis_types.Value, op: ops.JSONSet):

bigframes/core/compile/polars/compiler.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import bigframes.dtypes
3232
import bigframes.operations as ops
3333
import bigframes.operations.aggregations as agg_ops
34+
import bigframes.operations.array_ops as arr_ops
3435
import bigframes.operations.bool_ops as bool_ops
3536
import bigframes.operations.comparison_ops as comp_ops
3637
import bigframes.operations.datetime_ops as dt_ops
@@ -353,6 +354,36 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
353354
assert isinstance(op, json_ops.JSONDecode)
354355
return input.str.json_decode(_DTYPE_MAPPING[op.to_type])
355356

357+
@compile_op.register(arr_ops.ToArrayOp)
358+
def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr:
359+
return pl.concat_list(*inputs)
360+
361+
@compile_op.register(arr_ops.ArrayReduceOp)
362+
def _(self, op: ops.ArrayReduceOp, input: pl.Expr) -> pl.Expr:
363+
# TODO: Unify this with general aggregation compilation?
364+
if isinstance(op.aggregation, agg_ops.MinOp):
365+
return input.list.min()
366+
if isinstance(op.aggregation, agg_ops.MaxOp):
367+
return input.list.max()
368+
if isinstance(op.aggregation, agg_ops.SumOp):
369+
return input.list.sum()
370+
if isinstance(op.aggregation, agg_ops.MeanOp):
371+
return input.list.mean()
372+
if isinstance(op.aggregation, agg_ops.CountOp):
373+
return input.list.len()
374+
if isinstance(op.aggregation, agg_ops.StdOp):
375+
return input.list.std()
376+
if isinstance(op.aggregation, agg_ops.VarOp):
377+
return input.list.var()
378+
if isinstance(op.aggregation, agg_ops.AnyOp):
379+
return input.list.any()
380+
if isinstance(op.aggregation, agg_ops.AllOp):
381+
return input.list.all()
382+
else:
383+
raise NotImplementedError(
384+
f"Haven't implemented array aggregation: {op.aggregation}"
385+
)
386+
356387
@dataclasses.dataclass(frozen=True)
357388
class PolarsAggregateCompiler:
358389
scalar_compiler = PolarsExpressionCompiler()

bigframes/operations/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414

1515
from __future__ import annotations
1616

17-
from bigframes.operations.array_ops import ArrayIndexOp, ArraySliceOp, ArrayToStringOp
17+
from bigframes.operations.array_ops import (
18+
ArrayIndexOp,
19+
ArrayReduceOp,
20+
ArraySliceOp,
21+
ArrayToStringOp,
22+
ToArrayOp,
23+
)
1824
from bigframes.operations.base_ops import (
1925
BinaryOp,
2026
NaryOp,
@@ -405,4 +411,6 @@
405411
# Numpy ops mapping
406412
"NUMPY_TO_BINOP",
407413
"NUMPY_TO_OP",
414+
"ToArrayOp",
415+
"ArrayReduceOp",
408416
]

bigframes/operations/array_ops.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16+
import functools
1617
import typing
1718

1819
from bigframes import dtypes
19-
from bigframes.operations import base_ops
20+
from bigframes.operations import aggregations, base_ops
2021

2122

2223
@dataclasses.dataclass(frozen=True)
@@ -63,3 +64,27 @@ def output_type(self, *input_types):
6364
return input_type
6465
else:
6566
raise TypeError("Input type must be an array or string-like type.")
67+
68+
69+
class ToArrayOp(base_ops.NaryOp):
70+
name: typing.ClassVar[str] = "array"
71+
72+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
73+
# very permissive, maybe should force caller to do this?
74+
common_type = functools.reduce(
75+
lambda t1, t2: dtypes.coerce_to_common(t1, t2),
76+
input_types,
77+
)
78+
return dtypes.list_type(common_type)
79+
80+
81+
@dataclasses.dataclass(frozen=True)
82+
class ArrayReduceOp(base_ops.UnaryOp):
83+
name: typing.ClassVar[str] = "array_reduce"
84+
aggregation: aggregations.AggregateOp
85+
86+
def output_type(self, *input_types):
87+
input_type = input_types[0]
88+
assert dtypes.is_array_like(input_type)
89+
inner_type = dtypes.get_array_inner_type(input_type)
90+
return self.aggregation.output_type(inner_type)

tests/system/small/engines/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,10 @@ def repeated_data_source(
9090
repeated_pandas_df: pd.DataFrame,
9191
) -> local_data.ManagedArrowTable:
9292
return local_data.ManagedArrowTable.from_pandas(repeated_pandas_df)
93+
94+
95+
@pytest.fixture(scope="module")
96+
def arrays_array_value(
97+
repeated_data_source: local_data.ManagedArrowTable, fake_session: bigframes.Session
98+
):
99+
return ArrayValue.from_managed(repeated_data_source, fake_session)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from bigframes.core import array_value, expression
18+
import bigframes.operations as ops
19+
import bigframes.operations.aggregations as agg_ops
20+
from bigframes.session import polars_executor
21+
from bigframes.testing.engine_utils import assert_equivalence_execution
22+
23+
pytest.importorskip("polars")
24+
25+
# Polars used as reference as it's fast and local. Generally though, prefer gbq engine where they disagree.
26+
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
27+
28+
29+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
30+
def test_engines_to_array_op(scalars_array_value: array_value.ArrayValue, engine):
31+
# Bigquery won't allow you to materialize arrays with null, so use non-nullable
32+
int64_non_null = ops.coalesce_op.as_expr("int64_col", expression.const(0))
33+
bool_col_non_null = ops.coalesce_op.as_expr("bool_col", expression.const(False))
34+
float_col_non_null = ops.coalesce_op.as_expr("float64_col", expression.const(0.0))
35+
string_col_non_null = ops.coalesce_op.as_expr("string_col", expression.const(""))
36+
37+
arr, _ = scalars_array_value.compute_values(
38+
[
39+
ops.ToArrayOp().as_expr(int64_non_null),
40+
ops.ToArrayOp().as_expr(
41+
int64_non_null, bool_col_non_null, float_col_non_null
42+
),
43+
ops.ToArrayOp().as_expr(string_col_non_null, string_col_non_null),
44+
]
45+
)
46+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
47+
48+
49+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
50+
def test_engines_array_reduce_op(arrays_array_value: array_value.ArrayValue, engine):
51+
arr, _ = arrays_array_value.compute_values(
52+
[
53+
ops.ArrayReduceOp(agg_ops.SumOp()).as_expr("float_list_col"),
54+
ops.ArrayReduceOp(agg_ops.StdOp()).as_expr("float_list_col"),
55+
ops.ArrayReduceOp(agg_ops.MaxOp()).as_expr("date_list_col"),
56+
ops.ArrayReduceOp(agg_ops.CountOp()).as_expr("string_list_col"),
57+
ops.ArrayReduceOp(agg_ops.AnyOp()).as_expr("bool_list_col"),
58+
]
59+
)
60+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)

third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,9 @@ def visit_ArrayFilter(self, op, *, arg, body, param):
699699
def visit_ArrayMap(self, op, *, arg, body, param):
700700
return self.f.array(sg.select(body).from_(self._unnest(arg, as_=param)))
701701

702+
def visit_ArrayReduce(self, op, *, arg, body, param):
703+
return sg.select(body).from_(self._unnest(arg, as_=param)).subquery()
704+
702705
def visit_ArrayZip(self, op, *, arg):
703706
lengths = [self.f.array_length(arr) - 1 for arr in arg]
704707
idx = sg.to_identifier(util.gen_name("bq_arr_idx"))

third_party/bigframes_vendored/ibis/expr/operations/arrays.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,21 @@ def dtype(self) -> dt.DataType:
105105
return dt.Array(self.body.dtype)
106106

107107

108+
@public
109+
class ArrayReduce(Value):
110+
"""Apply a function to every element of an array."""
111+
112+
arg: Value[dt.Array]
113+
body: Value
114+
param: str
115+
116+
shape = rlz.shape_like("arg")
117+
118+
@attribute
119+
def dtype(self) -> dt.DataType:
120+
return self.body.dtype
121+
122+
108123
@public
109124
class ArrayFilter(Value):
110125
"""Filter array elements with a function."""

0 commit comments

Comments
 (0)