Skip to content

Commit 774e56b

Browse files
refactor: Simplify expression generation for some block ops (#1298)
1 parent 0318764 commit 774e56b

File tree

3 files changed

+56
-52
lines changed

3 files changed

+56
-52
lines changed

bigframes/core/block_transforms.py

Lines changed: 22 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -43,19 +43,16 @@ def equals(block1: blocks.Block, block2: blocks.Block) -> bool:
4343

4444
joined_block, (lmap, rmap) = block1.join(block2, how="outer")
4545

46-
equality_ids = []
46+
exprs = []
4747
for lcol, rcol in zip(block1.value_columns, block2.value_columns):
48-
lcolmapped = lmap[lcol]
49-
rcolmapped = rmap[rcol]
50-
joined_block, result_id = joined_block.project_expr(
48+
exprs.append(
5149
ops.fillna_op.as_expr(
52-
ops.eq_null_match_op.as_expr(lcolmapped, rcolmapped), ex.const(False)
50+
ops.eq_null_match_op.as_expr(lmap[lcol], rmap[rcol]), ex.const(False)
5351
)
5452
)
55-
equality_ids.append(result_id)
5653

57-
joined_block = joined_block.select_columns(equality_ids).with_column_labels(
58-
list(range(len(equality_ids)))
54+
joined_block = joined_block.project_exprs(
55+
exprs, labels=list(range(len(exprs))), drop=True
5956
)
6057
stacked_block = joined_block.stack()
6158
result = stacked_block.get_stat(stacked_block.value_columns[0], agg_ops.all_op)
@@ -395,12 +392,12 @@ def pct_change(block: blocks.Block, periods: int = 1) -> blocks.Block:
395392
block, shift_columns = block.multi_apply_window_op(
396393
original_columns, agg_ops.ShiftOp(periods), window_spec=window_spec
397394
)
398-
result_ids = []
395+
exprs = []
399396
for original_col, shifted_col in zip(original_columns, shift_columns):
400-
block, change_id = block.apply_binary_op(original_col, shifted_col, ops.sub_op)
401-
block, pct_change_id = block.apply_binary_op(change_id, shifted_col, ops.div_op)
402-
result_ids.append(pct_change_id)
403-
return block.select_columns(result_ids).with_column_labels(column_labels)
397+
change_expr = ops.sub_op.as_expr(original_col, shifted_col)
398+
pct_change_expr = ops.div_op.as_expr(change_expr, shifted_col)
399+
exprs.append(pct_change_expr)
400+
return block.project_exprs(exprs, labels=column_labels, drop=True)
404401

405402

406403
def rank(
@@ -470,16 +467,23 @@ def rank(
470467
# Step 3: post processing: mask null values and cast to float
471468
if method in ["min", "max", "first", "dense"]:
472469
# Pandas rank always produces Float64, so must cast for aggregation types that produce ints
473-
block = block.multi_apply_unary_op(
474-
rownum_col_ids, ops.AsTypeOp(pd.Float64Dtype())
470+
return (
471+
block.select_columns(rownum_col_ids)
472+
.multi_apply_unary_op(ops.AsTypeOp(pd.Float64Dtype()))
473+
.with_column_labels(labels)
475474
)
476475
if na_option == "keep":
477476
# For na_option "keep", null inputs must produce null outputs
477+
exprs = []
478478
for i in range(len(columns)):
479-
block, null_const = block.create_constant(pd.NA, dtype=pd.Float64Dtype())
480-
block, rownum_col_ids[i] = block.apply_ternary_op(
481-
null_const, nullity_col_ids[i], rownum_col_ids[i], ops.where_op
479+
exprs.append(
480+
ops.where_op.as_expr(
481+
ex.const(pd.NA, dtype=pd.Float64Dtype()),
482+
nullity_col_ids[i],
483+
rownum_col_ids[i],
484+
)
482485
)
486+
return block.project_exprs(exprs, labels=labels, drop=True)
483487

484488
return block.select_columns(rownum_col_ids).with_column_labels(labels)
485489

bigframes/core/blocks.py

Lines changed: 30 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -897,7 +897,6 @@ def multi_apply_window_op(
897897

898898
def multi_apply_unary_op(
899899
self,
900-
columns: typing.Sequence[str],
901900
op: Union[ops.UnaryOp, ex.Expression],
902901
) -> Block:
903902
if isinstance(op, ops.UnaryOp):
@@ -911,27 +910,37 @@ def multi_apply_unary_op(
911910

912911
block = self
913912

914-
result_ids = []
915-
for col_id in columns:
916-
label = self.col_id_to_label[col_id]
917-
block, result_id = block.project_expr(
918-
expr.bind_variables({input_varname: ex.deref(col_id)}),
919-
label=label,
920-
)
921-
block = block.copy_values(result_id, col_id)
922-
result_ids.append(result_id)
923-
block = block.drop_columns(result_ids)
913+
exprs = [
914+
expr.bind_variables({input_varname: ex.deref(col_id)})
915+
for col_id in self.value_columns
916+
]
917+
block = self.project_exprs(exprs, labels=self.column_labels, drop=True)
918+
924919
# Special case, we can preserve transpose cache for full-frame unary ops
925-
if (self._transpose_cache is not None) and set(self.value_columns) == set(
926-
columns
927-
):
928-
transpose_columns = self._transpose_cache.value_columns
929-
new_transpose_cache = self._transpose_cache.multi_apply_unary_op(
930-
transpose_columns, op
931-
)
920+
if self._transpose_cache is not None:
921+
new_transpose_cache = self._transpose_cache.multi_apply_unary_op(op)
932922
block = block.with_transpose_cache(new_transpose_cache)
933923
return block
934924

925+
def project_exprs(
926+
self,
927+
exprs: Sequence[ex.Expression],
928+
labels: Union[Sequence[Label], pd.Index],
929+
drop=False,
930+
) -> Block:
931+
new_array, _ = self.expr.compute_values(exprs)
932+
if drop:
933+
new_array = new_array.drop_columns(self.value_columns)
934+
935+
return Block(
936+
new_array,
937+
index_columns=self.index_columns,
938+
column_labels=labels
939+
if drop
940+
else self.column_labels.append(pd.Index(labels)),
941+
index_labels=self._index_labels,
942+
)
943+
935944
def apply_window_op(
936945
self,
937946
column: str,
@@ -2279,18 +2288,15 @@ def _apply_binop(
22792288
labels: pd.Index,
22802289
reverse: bool = False,
22812290
) -> Block:
2282-
block = self
2283-
binop_result_ids = []
2291+
exprs = []
22842292
for left_input, right_input in inputs:
2285-
expr = (
2293+
exprs.append(
22862294
op.as_expr(right_input, left_input)
22872295
if reverse
22882296
else op.as_expr(left_input, right_input)
22892297
)
2290-
block, result_col_id = block.project_expr(expr)
2291-
binop_result_ids.append(result_col_id)
22922298

2293-
return block.select_columns(binop_result_ids).with_column_labels(labels)
2299+
return self.project_exprs(exprs, labels=labels, drop=True)
22942300

22952301
def join(
22962302
self,

bigframes/dataframe.py

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -179,9 +179,7 @@ def __init__(
179179
if columns:
180180
block = block.select_columns(list(columns)) # type:ignore
181181
if dtype:
182-
block = block.multi_apply_unary_op(
183-
block.value_columns, ops.AsTypeOp(to_type=dtype)
184-
)
182+
block = block.multi_apply_unary_op(ops.AsTypeOp(to_type=dtype))
185183
self._block = block
186184

187185
else:
@@ -845,9 +843,7 @@ def _apply_scalar_binop(
845843
left_input=ex.free_var("var1"),
846844
right_input=ex.const(other),
847845
)
848-
return DataFrame(
849-
self._block.multi_apply_unary_op(self._block.value_columns, expr)
850-
)
846+
return DataFrame(self._block.multi_apply_unary_op(expr))
851847

852848
def _apply_series_binop_axis_0(
853849
self,
@@ -2400,9 +2396,7 @@ def dropna(
24002396
result = result.reset_index()
24012397
return DataFrame(result)
24022398
else:
2403-
isnull_block = self._block.multi_apply_unary_op(
2404-
self._block.value_columns, ops.isnull_op
2405-
)
2399+
isnull_block = self._block.multi_apply_unary_op(ops.isnull_op)
24062400
if how == "any":
24072401
null_locations = DataFrame(isnull_block).any().to_pandas()
24082402
else: # 'all'
@@ -3828,7 +3822,7 @@ def to_orc(self, path=None, **kwargs) -> bytes | None:
38283822
return as_pandas_default_index.to_orc(path, **kwargs)
38293823

38303824
def _apply_unary_op(self, operation: ops.UnaryOp) -> DataFrame:
3831-
block = self._block.multi_apply_unary_op(self._block.value_columns, operation)
3825+
block = self._block.multi_apply_unary_op(operation)
38323826
return DataFrame(block)
38333827

38343828
def _map_clustering_columns(

0 commit comments

Comments
 (0)