Commit 18a3c57

perf: Improve axis=1 aggregation performance
1 parent 935af10 commit 18a3c57
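
Previously, axis=1 aggregations unpivoted every value column into rows, grouped by a
generated row offset, and reassembled the index. This change instead packs the value
columns into a single array (new ToArrayOp) and reduces that array element-wise (new
ArrayReduceOp), which compiles to a per-row scalar expression rather than an unpivot
plus group-by.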

File tree

10 files changed: +124 −44 lines changed

bigframes/core/blocks.py

Lines changed: 4 additions & 40 deletions
@@ -1232,46 +1232,10 @@ def aggregate_all_and_stack(
                 index_labels=[None],
             ).transpose(original_row_index=pd.Index([None]), single_row_mode=True)
         else:  # axis_n == 1
-            # using offsets as identity to group on.
-            # TODO: Allow to promote identity/total_order columns instead for better perf
-            expr_with_offsets, offset_col = self.expr.promote_offsets()
-            stacked_expr, (_, value_col_ids, passthrough_cols,) = unpivot(
-                expr_with_offsets,
-                row_labels=self.column_labels,
-                unpivot_columns=[tuple(self.value_columns)],
-                passthrough_columns=[*self.index_columns, offset_col],
-            )
-            # these corresponed to passthrough_columns provided to unpivot
-            index_cols = passthrough_cols[:-1]
-            og_offset_col = passthrough_cols[-1]
-            index_aggregations = [
-                (
-                    ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id)),
-                    col_id,
-                )
-                for col_id in index_cols
-            ]
-            # TODO: may need add NullaryAggregation in main_aggregation
-            # when agg add support for axis=1, needed for agg("size", axis=1)
-            assert isinstance(
-                operation, agg_ops.UnaryAggregateOp
-            ), f"Expected a unary operation, but got {operation}. Please report this error and how you got here to the BigQuery DataFrames team (bit.ly/bigframes-feedback)."
-            main_aggregation = (
-                ex.UnaryAggregation(operation, ex.deref(value_col_ids[0])),
-                value_col_ids[0],
-            )
-            # Drop row identity after aggregating over it
-            result_expr = stacked_expr.aggregate(
-                [*index_aggregations, main_aggregation],
-                by_column_ids=[og_offset_col],
-                dropna=dropna,
-            ).drop_columns([og_offset_col])
-            return Block(
-                result_expr,
-                index_columns=index_cols,
-                column_labels=[None],
-                index_labels=self.index.names,
-            )
+            as_array = ops.ToArrayOp().as_expr(*(col for col in self.value_columns))
+            reduced = ops.ArrayReduceOp(operation).as_expr(as_array)
+            block, id = self.project_expr(reduced, None)
+            return block.select_column(id)
 
     def aggregate_size(
         self,
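
The rewritten branch turns the whole row-wise aggregation into one projected
expression. A minimal user-level sketch of the code path this optimizes (the
DataFrame contents are illustrative, not from the commit):

    import bigframes.pandas as bpd

    df = bpd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})

    # Previously compiled to an unpivot plus GROUP BY over a generated row
    # offset; with this change it compiles to ARRAY[a, b, c] reduced per row.
    row_sums = df.sum(axis=1)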

bigframes/core/compile/ibis_compiler/aggregate_compiler.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def _(
 ) -> ibis_types.NumericValue:
     # Will be null if all inputs are null. Pandas defaults to zero sum though.
     bq_sum = _apply_window_if_present(column.sum(), window)
-    return bq_sum.fill_null(ibis_types.literal(0))
+    return bq_sum.coalesce(ibis_types.literal(0))
 
 
 @compile_unary_agg.register
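
Both fill_null and coalesce produce NULL-safe SQL; the switch is presumably so the
expression also compiles inside an ArrayReduce body, where the patched value is a
generic scalar expression. A minimal ibis sketch of the NULL-safe sum (memtable
data is illustrative):

    import ibis

    t = ibis.memtable({"x": [1, None, 2]})
    # NULL if all inputs are NULL, so default to pandas' zero sum.
    total = t.x.sum().coalesce(ibis.literal(0))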

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 20 additions & 0 deletions
@@ -1201,6 +1201,26 @@ def array_slice_op_impl(x: ibis_types.Value, op: ops.ArraySliceOp):
     return res
 
 
+@scalar_op_compiler.register_nary_op(ops.ToArrayOp, pass_op=False)
+def to_arry_op_impl(*values: ibis_types.Value):
+    do_upcast_bool = any(t.type().is_numeric() for t in values)
+    if do_upcast_bool:
+        values = tuple(
+            val.cast(ibis_dtypes.int64) if val.type().is_boolean() else val
+            for val in values
+        )
+    return ibis_api.array(values)
+
+
+@scalar_op_compiler.register_unary_op(ops.ArrayReduceOp, pass_op=True)
+def array_reduce_op_impl(x: ibis_types.Value, op: ops.ArrayReduceOp):
+    import bigframes.core.compile.ibis_compiler.aggregate_compiler as agg_compilers
+
+    return typing.cast(ibis_types.ArrayValue, x).reduce(
+        lambda arr_vals: agg_compilers.compile_unary_agg(op.aggregation, arr_vals)
+    )
+
+
 # JSON Ops
 @scalar_op_compiler.register_binary_op(ops.JSONSet, pass_op=True)
 def json_set_op_impl(x: ibis_types.Value, y: ibis_types.Value, op: ops.JSONSet):
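
BigQuery arrays are homogeneous, so when boolean and numeric columns are mixed the
booleans are upcast to INT64 before the array is built; sum or mean over the array
then treats True/False as 1/0. A sketch of the same construction in plain ibis
(table and column names are illustrative):

    import ibis

    t = ibis.memtable({"x": [1, 2], "flag": [True, False]})
    # Mirrors the upcast in to_arry_op_impl: booleans join numerics as 0/1.
    arr = ibis.array([t.x, t.flag.cast("int64")])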

bigframes/core/compile/polars/compiler.py

Lines changed: 27 additions & 0 deletions
@@ -31,6 +31,7 @@
 import bigframes.dtypes
 import bigframes.operations as ops
 import bigframes.operations.aggregations as agg_ops
+import bigframes.operations.array_ops as arr_ops
 import bigframes.operations.bool_ops as bool_ops
 import bigframes.operations.comparison_ops as comp_ops
 import bigframes.operations.datetime_ops as dt_ops
@@ -353,6 +354,32 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
         assert isinstance(op, json_ops.JSONDecode)
         return input.str.json_decode(_DTYPE_MAPPING[op.to_type])
 
+    @compile_op.register(arr_ops.ToArrayOp)
+    def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr:
+        return pl.concat_list(*inputs)
+
+    @compile_op.register(arr_ops.ArrayReduceOp)
+    def _(self, op: ops.ArrayReduceOp, input: pl.Expr) -> pl.Expr:
+        # TODO: Unify this with general aggregation compilation?
+        if isinstance(op.aggregation, agg_ops.MinOp):
+            return input.list.min()
+        if isinstance(op.aggregation, agg_ops.MaxOp):
+            return input.list.max()
+        if isinstance(op.aggregation, agg_ops.SumOp):
+            return input.list.sum()
+        if isinstance(op.aggregation, agg_ops.MeanOp):
+            return input.list.mean()
+        if isinstance(op.aggregation, agg_ops.CountOp):
+            return input.list.len()
+        if isinstance(op.aggregation, agg_ops.StdOp):
+            return input.list.std()
+        if isinstance(op.aggregation, agg_ops.VarOp):
+            return input.list.var()
+        else:
+            raise NotImplementedError(
+                f"Haven't implemented array aggregation: {op.aggregation}"
+            )
+
 
 @dataclasses.dataclass(frozen=True)
 class PolarsAggregateCompiler:
     scalar_compiler = PolarsExpressionCompiler()
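
The polars backend maps the two new ops directly onto the list namespace:
pl.concat_list builds the row-wise array and the list aggregations reduce it. A
self-contained sketch of the equivalent transformation (data is illustrative):

    import polars as pl

    df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})

    # ToArrayOp -> pl.concat_list; ArrayReduceOp(SumOp) -> .list.sum()
    row_sums = df.select(pl.concat_list("a", "b", "c").list.sum().alias("row_sum"))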

bigframes/operations/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -14,7 +14,13 @@
 
 from __future__ import annotations
 
-from bigframes.operations.array_ops import ArrayIndexOp, ArraySliceOp, ArrayToStringOp
+from bigframes.operations.array_ops import (
+    ArrayIndexOp,
+    ArrayReduceOp,
+    ArraySliceOp,
+    ArrayToStringOp,
+    ToArrayOp,
+)
 from bigframes.operations.base_ops import (
     BinaryOp,
     NaryOp,
@@ -405,4 +411,6 @@
     # Numpy ops mapping
     "NUMPY_TO_BINOP",
     "NUMPY_TO_OP",
+    "ToArrayOp",
+    "ArrayReduceOp",
 ]

bigframes/operations/array_ops.py

Lines changed: 26 additions & 1 deletion
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 import dataclasses
+import functools
 import typing
 
 from bigframes import dtypes
-from bigframes.operations import base_ops
+from bigframes.operations import aggregations, base_ops
 
 
 @dataclasses.dataclass(frozen=True)
@@ -63,3 +64,27 @@ def output_type(self, *input_types):
             return input_type
         else:
             raise TypeError("Input type must be an array or string-like type.")
+
+
+class ToArrayOp(base_ops.NaryOp):
+    name: typing.ClassVar[str] = "array"
+
+    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
+        # very permissive, maybe should force caller to do this?
+        common_type = functools.reduce(
+            lambda t1, t2: dtypes.coerce_to_common(t1, t2),
+            input_types,
+        )
+        return dtypes.list_type(common_type)
+
+
+@dataclasses.dataclass(frozen=True)
+class ArrayReduceOp(base_ops.UnaryOp):
+    name: typing.ClassVar[str] = "array_reduce"
+    aggregation: aggregations.AggregateOp
+
+    def output_type(self, *input_types):
+        input_type = input_types[0]
+        assert dtypes.is_array_like(input_type)
+        inner_type = dtypes.get_array_inner_type(input_type)
+        return self.aggregation.output_type(inner_type)
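
The two output_type methods compose: ToArrayOp coerces its inputs to a common
element type and wraps it in a list dtype, and ArrayReduceOp unwraps the element
type and defers to the wrapped aggregation. A rough illustration (the dtype
constants are assumed to exist in bigframes.dtypes):

    from bigframes import dtypes
    from bigframes.operations import aggregations as agg_ops
    from bigframes.operations.array_ops import ArrayReduceOp, ToArrayOp

    arr_type = ToArrayOp().output_type(dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE)
    # -> list dtype with the common (float) element type
    out_type = ArrayReduceOp(agg_ops.SumOp()).output_type(arr_type)
    # -> SumOp's output type for a float element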

third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -699,6 +699,9 @@ def visit_ArrayFilter(self, op, *, arg, body, param):
     def visit_ArrayMap(self, op, *, arg, body, param):
         return self.f.array(sg.select(body).from_(self._unnest(arg, as_=param)))
 
+    def visit_ArrayReduce(self, op, *, arg, body, param):
+        return sg.select(body).from_(self._unnest(arg, as_=param)).subquery()
+
     def visit_ArrayZip(self, op, *, arg):
         lengths = [self.f.array_length(arr) - 1 for arr in arg]
         idx = sg.to_identifier(util.gen_name("bq_arr_idx"))
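
Unlike visit_ArrayMap, which wraps the unnested SELECT back into an array, the new
visitor leaves it as a scalar subquery, so the generated SQL has roughly the shape
(illustrative): (SELECT SUM(el) FROM UNNEST(arr) AS el).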

third_party/bigframes_vendored/ibis/expr/operations/arrays.py

Lines changed: 15 additions & 0 deletions
@@ -105,6 +105,21 @@ def dtype(self) -> dt.DataType:
         return dt.Array(self.body.dtype)
 
 
+@public
+class ArrayReduce(Value):
+    """Reduce an array to a single value by aggregating its elements."""
+
+    arg: Value[dt.Array]
+    body: Value
+    param: str
+
+    shape = rlz.shape_like("arg")
+
+    @attribute
+    def dtype(self) -> dt.DataType:
+        return self.body.dtype
+
+
 @public
 class ArrayFilter(Value):
     """Filter array elements with a function."""

third_party/bigframes_vendored/ibis/expr/rewrites.py

Lines changed: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ def rewrite_project_input(value, relation):
     # relation
     return value.replace(
         project_wrap_analytic | project_wrap_reduction,
-        filter=p.Value & ~p.WindowFunction,
+        filter=p.Value & ~p.WindowFunction & ~p.ArrayReduce,
         context={"rel": relation},
     )
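
Presumably the exclusion keeps this rewrite from treating the aggregation inside an
ArrayReduce body as a table-level reduction and wrapping it in its own projection;
the per-element reduction stays inline instead.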

third_party/bigframes_vendored/ibis/expr/types/arrays.py

Lines changed: 18 additions & 0 deletions
@@ -486,6 +486,24 @@ def map(self, func: Deferred | Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
         body = resolve(parameter.to_expr())
         return ops.ArrayMap(self, param=parameter.param, body=body).to_expr()
 
+    def reduce(self, func: Deferred | Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
+        if isinstance(func, Deferred):
+            name = "_"
+            resolve = func.resolve
+        elif callable(func):
+            name = next(iter(inspect.signature(func).parameters.keys()))
+            resolve = func
+        else:
+            raise TypeError(
+                f"`func` must be a Deferred or Callable, got `{type(func).__name__}`"
+            )
+
+        parameter = ops.Argument(
+            name=name, shape=self.op().shape, dtype=self.type().value_type
+        )
+        body = resolve(parameter.to_expr())
+        return ops.ArrayReduce(self, param=parameter.param, body=body).to_expr()
+
     def filter(
         self, predicate: Deferred | Callable[[ir.Value], bool | ir.BooleanValue]
     ) -> ir.ArrayValue:
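
A short usage sketch of the new reduce method, mirroring map above (the vendored
import path and memtable data are assumptions, not from the commit):

    import bigframes_vendored.ibis as ibis
    from bigframes_vendored.ibis import _

    t = ibis.memtable({"arr": [[1, 2, 3], [4, 5]]})

    sums = t.arr.reduce(lambda el: el.sum())  # callable form
    sums_deferred = t.arr.reduce(_.sum())     # Deferred form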
