Commit 4eb64f6

refactor: remove 'partial' ops and replace with expressions (#314)
1 parent 9150c16 commit 4eb64f6
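
The change is the same throughout the diff: chains of apply_unary_op / apply_binary_op / apply_ternary_op calls that bind constants with ops.partial_right or ops.partial_arg3 are collapsed into a single expression tree built with as_expr(...) and ex.const(...), then materialized with one Block.project_expr call, which (as the hunks below show) still returns the updated block together with the id of the new value column. A minimal before/after sketch of the pattern, copied from the nsmallest hunk in this file (counter and n are the names used there):

# Before: the constant is bound into a partial op, applied one op at a time
block, condition = block.apply_unary_op(
    counter, ops.partial_right(ops.le_op, n)
)

# After: the constant becomes an ex.const(...) leaf in an expression tree,
# projected in a single step
block, condition = block.project_expr(ops.le_op.as_expr(counter, ex.const(n)))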

File tree

13 files changed: +295 -358 lines changed

bigframes/core/block_transforms.py

Lines changed: 66 additions & 89 deletions
@@ -20,6 +20,7 @@
 import bigframes.constants as constants
 import bigframes.core as core
 import bigframes.core.blocks as blocks
+import bigframes.core.expression as ex
 import bigframes.core.ordering as ordering
 import bigframes.core.window_spec as windows
 import bigframes.dtypes as dtypes
@@ -44,11 +45,10 @@ def equals(block1: blocks.Block, block2: blocks.Block) -> bool:
     for lcol, rcol in zip(block1.value_columns, block2.value_columns):
         lcolmapped = lmap[lcol]
         rcolmapped = rmap[rcol]
-        joined_block, result_id = joined_block.apply_binary_op(
-            lcolmapped, rcolmapped, ops.eq_null_match_op
-        )
-        joined_block, result_id = joined_block.apply_unary_op(
-            result_id, ops.partial_right(ops.fillna_op, False)
+        joined_block, result_id = joined_block.project_expr(
+            ops.fillna_op.as_expr(
+                ops.eq_null_match_op.as_expr(lcolmapped, rcolmapped), ex.const(False)
+            )
         )
         equality_ids.append(result_id)

@@ -91,9 +91,8 @@ def indicate_duplicates(
         agg_ops.count_op,
         window_spec=window_spec,
     )
-    block, duplicate_indicator = block.apply_unary_op(
-        val_count_col_id,
-        ops.partial_right(ops.gt_op, 1),
+    block, duplicate_indicator = block.project_expr(
+        ops.gt_op.as_expr(val_count_col_id, ex.const(1))
     )
     return (
         block.drop_columns(
@@ -183,8 +182,8 @@ def _interpolate_column(

     # Note, this method may
     block, notnull = block.apply_unary_op(column, ops.notnull_op)
-    block, masked_offsets = block.apply_binary_op(
-        x_values, notnull, ops.partial_arg3(ops.where_op, None)
+    block, masked_offsets = block.project_expr(
+        ops.where_op.as_expr(x_values, notnull, ex.const(None))
     )

     block, previous_value = block.apply_window_op(
@@ -271,25 +270,22 @@ def _interpolate_points_nearest(
     xpredict_id: str,
 ) -> typing.Tuple[blocks.Block, str]:
     """Interpolate by taking the y value of the nearest x value"""
-    block, left_diff = block.apply_binary_op(xpredict_id, x0_id, ops.sub_op)
-    block, right_diff = block.apply_binary_op(x1_id, xpredict_id, ops.sub_op)
+    left_diff = ops.sub_op.as_expr(xpredict_id, x0_id)
+    right_diff = ops.sub_op.as_expr(x1_id, xpredict_id)
     # If diffs equal, choose left
-    block, choose_left = block.apply_binary_op(left_diff, right_diff, ops.le_op)
-    block, choose_left = block.apply_unary_op(
-        choose_left, ops.partial_right(ops.fillna_op, False)
+    choose_left = ops.fillna_op.as_expr(
+        ops.le_op.as_expr(left_diff, right_diff), ex.const(False)
     )

-    block, nearest = block.apply_ternary_op(y0_id, choose_left, y1_id, ops.where_op)
-
-    block, y0_exists = block.apply_unary_op(y0_id, ops.notnull_op)
-    block, y1_exists = block.apply_unary_op(y1_id, ops.notnull_op)
-    block, is_interpolation = block.apply_binary_op(y0_exists, y1_exists, ops.and_op)
+    nearest = ops.where_op.as_expr(y0_id, choose_left, y1_id)

-    block, prediction_id = block.apply_binary_op(
-        nearest, is_interpolation, ops.partial_arg3(ops.where_op, None)
+    is_interpolation = ops.and_op.as_expr(
+        ops.notnull_op.as_expr(y0_id), ops.notnull_op.as_expr(y1_id)
     )

-    return block, prediction_id
+    return block.project_expr(
+        ops.where_op.as_expr(nearest, is_interpolation, ex.const(None))
+    )


 def _interpolate_points_ffill(
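
For reference, a scalar sketch of what the rewritten _interpolate_points_nearest expression computes; the names mirror the column ids above, and the helper itself is illustrative, not part of the commit:

from typing import Optional

def nearest_point(
    x0: float, y0: Optional[float], x1: float, y1: Optional[float], xpredict: float
) -> Optional[float]:
    # Only interpolate when both neighboring y values exist (the and_op/notnull_op
    # guard); otherwise the projected column is NULL.
    if y0 is None or y1 is None:
        return None
    # A tie in distance chooses the left point, matching le_op wrapped in fillna(False).
    choose_left = (xpredict - x0) <= (x1 - xpredict)
    return y0 if choose_left else y1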
@@ -302,11 +298,9 @@ def _interpolate_points_ffill(
 ) -> typing.Tuple[blocks.Block, str]:
     """Interpolates by using the preceding values"""
     # check for existance of y1, otherwise we are extrapolating instead of interpolating
-    block, y1_exists = block.apply_unary_op(y1_id, ops.notnull_op)
-    block, prediction_id = block.apply_binary_op(
-        y0_id, y1_exists, ops.partial_arg3(ops.where_op, None)
+    return block.project_expr(
+        ops.where_op.as_expr(y0_id, ops.notnull_op.as_expr(y1_id), ex.const(None))
     )
-    return block, prediction_id


 def drop_duplicates(
@@ -519,9 +513,7 @@ def nsmallest(
         agg_ops.rank_op,
         window_spec=windows.WindowSpec(ordering=tuple(order_refs)),
     )
-    block, condition = block.apply_unary_op(
-        counter, ops.partial_right(ops.le_op, n)
-    )
+    block, condition = block.project_expr(ops.le_op.as_expr(counter, ex.const(n)))
     block = block.filter(condition)
     return block.drop_columns([counter, condition])

@@ -551,9 +543,7 @@ def nlargest(
         agg_ops.rank_op,
         window_spec=windows.WindowSpec(ordering=tuple(order_refs)),
    )
-    block, condition = block.apply_unary_op(
-        counter, ops.partial_right(ops.le_op, n)
-    )
+    block, condition = block.project_expr(ops.le_op.as_expr(counter, ex.const(n)))
     block = block.filter(condition)
     return block.drop_columns([counter, condition])

@@ -641,19 +631,18 @@ def kurt(

 def _mean_delta_to_power(
     block: blocks.Block,
-    n_power,
+    n_power: int,
     column_ids: typing.Sequence[str],
     grouping_column_ids: typing.Sequence[str],
 ) -> typing.Tuple[blocks.Block, typing.Sequence[str]]:
     """Calculate (x-mean(x))^n. Useful for calculating moment statistics such as skew and kurtosis."""
     window = windows.WindowSpec(grouping_keys=tuple(grouping_column_ids))
     block, mean_ids = block.multi_apply_window_op(column_ids, agg_ops.mean_op, window)
     delta_ids = []
-    cube_op = ops.partial_right(ops.pow_op, n_power)
     for val_id, mean_val_id in zip(column_ids, mean_ids):
-        block, delta_id = block.apply_binary_op(val_id, mean_val_id, ops.sub_op)
-        block, delta_power_id = block.apply_unary_op(delta_id, cube_op)
-        block = block.drop_columns([delta_id])
+        delta = ops.sub_op.as_expr(val_id, mean_val_id)
+        delta_power = ops.pow_op.as_expr(delta, ex.const(n_power))
+        block, delta_power_id = block.project_expr(delta_power)
         delta_ids.append(delta_power_id)
     return block, delta_ids

@@ -664,31 +653,26 @@ def _skew_from_moments_and_count(
     # Calculate skew using count, third moment and population variance
     # See G1 estimator:
     # https://en.wikipedia.org/wiki/Skewness#Sample_skewness
-    block, denominator_id = block.apply_unary_op(
-        moment2_id, ops.partial_right(ops.unsafe_pow_op, 3 / 2)
-    )
-    block, base_id = block.apply_binary_op(moment3_id, denominator_id, ops.div_op)
-    block, countminus1_id = block.apply_unary_op(
-        count_id, ops.partial_right(ops.sub_op, 1)
-    )
-    block, countminus2_id = block.apply_unary_op(
-        count_id, ops.partial_right(ops.sub_op, 2)
-    )
-    block, adjustment_id = block.apply_binary_op(count_id, countminus1_id, ops.mul_op)
-    block, adjustment_id = block.apply_unary_op(
-        adjustment_id, ops.partial_right(ops.unsafe_pow_op, 1 / 2)
+    moments_estimator = ops.div_op.as_expr(
+        moment3_id, ops.pow_op.as_expr(moment2_id, ex.const(3 / 2))
     )
-    block, adjustment_id = block.apply_binary_op(
-        adjustment_id, countminus2_id, ops.div_op
+
+    countminus1 = ops.sub_op.as_expr(count_id, ex.const(1))
+    countminus2 = ops.sub_op.as_expr(count_id, ex.const(2))
+    adjustment = ops.div_op.as_expr(
+        ops.unsafe_pow_op.as_expr(
+            ops.mul_op.as_expr(count_id, countminus1), ex.const(1 / 2)
+        ),
+        countminus2,
     )
-    block, skew_id = block.apply_binary_op(base_id, adjustment_id, ops.mul_op)
+
+    skew = ops.mul_op.as_expr(moments_estimator, adjustment)

     # Need to produce NA if have less than 3 data points
-    block, na_cond_id = block.apply_unary_op(count_id, ops.partial_right(ops.ge_op, 3))
-    block, skew_id = block.apply_binary_op(
-        skew_id, na_cond_id, ops.partial_arg3(ops.where_op, None)
+    cleaned_skew = ops.where_op.as_expr(
+        skew, ops.ge_op.as_expr(count_id, ex.const(3)), ex.const(None)
     )
-    return block, skew_id
+    return block.project_expr(cleaned_skew)


 def _kurt_from_moments_and_count(
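
The rewritten skew logic is still the G1 estimator cited in the comment, only built as one expression tree. A scalar equivalent for reference (m2 and m3 are the second and third central moments, n the count; the helper below is illustrative, not part of the commit):

import math
from typing import Optional

def skew_from_moments(m2: float, m3: float, n: int) -> Optional[float]:
    if n < 3:  # mirrors the ge_op(count_id, 3) / where_op(..., None) guard
        return None
    moments_estimator = m3 / m2 ** (3 / 2)
    adjustment = math.sqrt(n * (n - 1)) / (n - 2)
    return moments_estimator * adjustment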
@@ -701,49 +685,42 @@ def _kurt_from_moments_and_count(
     # adjustment = 3 * (count - 1) ** 2 / ((count - 2) * (count - 3))
     # kurtosis = (numerator / denominator) - adjustment

-    # Numerator
-    block, countminus1_id = block.apply_unary_op(
-        count_id, ops.partial_right(ops.sub_op, 1)
-    )
-    block, countplus1_id = block.apply_unary_op(
-        count_id, ops.partial_right(ops.add_op, 1)
+    numerator = ops.mul_op.as_expr(
+        moment4_id,
+        ops.mul_op.as_expr(
+            ops.sub_op.as_expr(count_id, ex.const(1)),
+            ops.add_op.as_expr(count_id, ex.const(1)),
+        ),
     )
-    block, num_adj = block.apply_binary_op(countplus1_id, countminus1_id, ops.mul_op)
-    block, numerator_id = block.apply_binary_op(moment4_id, num_adj, ops.mul_op)

     # Denominator
-    block, countminus2_id = block.apply_unary_op(
-        count_id, ops.partial_right(ops.sub_op, 2)
-    )
-    block, countminus3_id = block.apply_unary_op(
-        count_id, ops.partial_right(ops.sub_op, 3)
-    )
-    block, denom_adj = block.apply_binary_op(countminus2_id, countminus3_id, ops.mul_op)
-    block, popvar_squared = block.apply_unary_op(
-        moment2_id, ops.partial_right(ops.unsafe_pow_op, 2)
+    countminus2 = ops.sub_op.as_expr(count_id, ex.const(2))
+    countminus3 = ops.sub_op.as_expr(count_id, ex.const(3))
+
+    # Denominator
+    denominator = ops.mul_op.as_expr(
+        ops.unsafe_pow_op.as_expr(moment2_id, ex.const(2)),
+        ops.mul_op.as_expr(countminus2, countminus3),
     )
-    block, denominator_id = block.apply_binary_op(popvar_squared, denom_adj, ops.mul_op)

     # Adjustment
-    block, countminus1_square = block.apply_unary_op(
-        countminus1_id, ops.partial_right(ops.unsafe_pow_op, 2)
-    )
-    block, adj_num = block.apply_unary_op(
-        countminus1_square, ops.partial_right(ops.mul_op, 3)
+    adj_num = ops.mul_op.as_expr(
+        ops.unsafe_pow_op.as_expr(
+            ops.sub_op.as_expr(count_id, ex.const(1)), ex.const(2)
+        ),
+        ex.const(3),
     )
-    block, adj_denom = block.apply_binary_op(countminus2_id, countminus3_id, ops.mul_op)
-    block, adjustment_id = block.apply_binary_op(adj_num, adj_denom, ops.div_op)
+    adj_denom = ops.mul_op.as_expr(countminus2, countminus3)
+    adjustment = ops.div_op.as_expr(adj_num, adj_denom)

     # Combine
-    block, base_id = block.apply_binary_op(numerator_id, denominator_id, ops.div_op)
-    block, kurt_id = block.apply_binary_op(base_id, adjustment_id, ops.sub_op)
+    kurt = ops.sub_op.as_expr(ops.div_op.as_expr(numerator, denominator), adjustment)

     # Need to produce NA if have less than 4 data points
-    block, na_cond_id = block.apply_unary_op(count_id, ops.partial_right(ops.ge_op, 4))
-    block, kurt_id = block.apply_binary_op(
-        kurt_id, na_cond_id, ops.partial_arg3(ops.where_op, None)
+    cleaned_kurt = ops.where_op.as_expr(
+        kurt, ops.ge_op.as_expr(count_id, ex.const(4)), ex.const(None)
     )
-    return block, kurt_id
+    return block.project_expr(cleaned_kurt)


 def align(
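
Likewise, the kurtosis expression follows the commented formulas directly; a scalar equivalent for reference (m2 and m4 are the second and fourth central moments, n the count; the helper is illustrative, not part of the commit):

from typing import Optional

def kurt_from_moments(m2: float, m4: float, n: int) -> Optional[float]:
    if n < 4:  # mirrors the ge_op(count_id, 4) / where_op(..., None) guard
        return None
    numerator = m4 * (n - 1) * (n + 1)
    denominator = m2**2 * (n - 2) * (n - 3)
    adjustment = 3 * (n - 1) ** 2 / ((n - 2) * (n - 3))
    return numerator / denominator - adjustment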
