Commit 4abc26c

port kurt, skew to new framework
1 parent cba77b8 commit 4abc26c

File tree: 4 files changed, +104 -77 lines changed

bigframes/core/array_value.py

Lines changed: 8 additions & 1 deletion
@@ -279,14 +279,21 @@ def compute_general_reduction(
         self,
         assignments: Sequence[ex.Expression],
         by_column_ids: typing.Sequence[str] = (),
+        *,
+        dropna: bool = False,
     ):
         # Warning: this function does not check if the expression is a valid reduction, and may fail spectacularly on invalid inputs
+        plan = self.node
+        if dropna:
+            for col_id in by_column_ids:
+                plan = nodes.FilterNode(plan, ops.notnull_op.as_expr(col_id))
+
         named_exprs = [
             nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
         ]
         # TODO: Push this to rewrite later to go from block expression to planning form
         new_root = expression_factoring.plan_general_aggregation(
-            self.node, named_exprs, grouping_keys=[ex.deref(by) for by in by_column_ids]
+            plan, named_exprs, grouping_keys=[ex.deref(by) for by in by_column_ids]
         )
         target_ids = tuple(named_expr.id for named_expr in named_exprs)
         return (ArrayValue(new_root), target_ids)
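
The new dropna flag pre-filters rows whose grouping keys are null before the aggregation is planned, one FilterNode per key column. A minimal pandas sketch of the intended semantics (illustrative only, not this API; the frame and column names are made up):

    import pandas as pd

    df = pd.DataFrame({"key": ["a", None, "a"], "x": [1, 2, 3]})
    df.groupby("key", dropna=True)["x"].count()   # null-key row filtered out first
    df.groupby("key", dropna=False)["x"].count()  # null key kept as its own group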

bigframes/core/block_transforms.py

Lines changed: 40 additions & 75 deletions
@@ -623,40 +623,28 @@ def skew(
     original_columns = skew_column_ids
     column_labels = block.select_columns(original_columns).column_labels

-    block, delta3_ids = _mean_delta_to_power(
-        block, 3, original_columns, grouping_column_ids
-    )
     # counts, moment3 for each column
     aggregations = []
-    for i, col in enumerate(original_columns):
+    for col in original_columns:
+        delta3_expr = _mean_delta_to_power(3, col)
         count_agg = agg_expressions.UnaryAggregation(
             agg_ops.count_op,
             ex.deref(col),
         )
         moment3_agg = agg_expressions.UnaryAggregation(
             agg_ops.mean_op,
-            ex.deref(delta3_ids[i]),
+            delta3_expr,
         )
         variance_agg = agg_expressions.UnaryAggregation(
             agg_ops.PopVarOp(),
             ex.deref(col),
         )
-        aggregations.extend([count_agg, moment3_agg, variance_agg])
+        skew_expr = _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
+        aggregations.append(skew_expr)

-    block, agg_ids = block.aggregate(
-        by_column_ids=grouping_column_ids, aggregations=aggregations
+    block, _ = block.reduce_general(
+        aggregations, grouping_column_ids, column_labels=column_labels
     )
-
-    skew_ids = []
-    for i, col in enumerate(original_columns):
-        # Corresponds to order of aggregations in preceding loop
-        count_id, moment3_id, var_id = agg_ids[i * 3 : (i * 3) + 3]
-        block, skew_id = _skew_from_moments_and_count(
-            block, count_id, moment3_id, var_id
-        )
-        skew_ids.append(skew_id)
-
-    block = block.select_columns(skew_ids).with_column_labels(column_labels)
     if not grouping_column_ids:
         # When ungrouped, transpose result row into a series
         # perform transpose last, so as to not invalidate cache
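
The structural change here: instead of materializing delta columns with a window op, aggregating, then post-processing result ids, each output column is now a single expression tree that mixes aggregation nodes with scalar operators, and one reduce_general call plans the whole computation. A minimal sketch in the same style (hypothetical column ids "x" and "key"; assumes this module's imports), deriving sample variance from population variance and count:

    count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref("x"))
    popvar_agg = agg_expressions.UnaryAggregation(agg_ops.PopVarOp(), ex.deref("x"))
    # s^2 = popvar * n / (n - 1): scalar ops layered over aggregation nodes
    sample_var_expr = ops.mul_op.as_expr(
        popvar_agg,
        ops.div_op.as_expr(count_agg, ops.sub_op.as_expr(count_agg, ex.const(1))),
    )
    block, out_ids = block.reduce_general([sample_var_expr], ["key"])
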
@@ -673,36 +661,23 @@ def kurt(
 ) -> blocks.Block:
     original_columns = skew_column_ids
     column_labels = block.select_columns(original_columns).column_labels
-
-    block, delta4_ids = _mean_delta_to_power(
-        block, 4, original_columns, grouping_column_ids
-    )
     # counts, moment4 for each column
-    aggregations = []
-    for i, col in enumerate(original_columns):
+    kurt_exprs = []
+    for col in original_columns:
+        delta_4_expr = _mean_delta_to_power(4, col)
         count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref(col))
-        moment4_agg = agg_expressions.UnaryAggregation(
-            agg_ops.mean_op, ex.deref(delta4_ids[i])
-        )
+        moment4_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, delta_4_expr)
         variance_agg = agg_expressions.UnaryAggregation(
             agg_ops.PopVarOp(), ex.deref(col)
         )
-        aggregations.extend([count_agg, moment4_agg, variance_agg])

-    block, agg_ids = block.aggregate(
-        by_column_ids=grouping_column_ids, aggregations=aggregations
-    )
-
-    kurt_ids = []
-    for i, col in enumerate(original_columns):
         # Corresponds to order of aggregations in preceding loop
-        count_id, moment4_id, var_id = agg_ids[i * 3 : (i * 3) + 3]
-        block, kurt_id = _kurt_from_moments_and_count(
-            block, count_id, moment4_id, var_id
-        )
-        kurt_ids.append(kurt_id)
+        kurt_expr = _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
+        kurt_exprs.append(kurt_expr)

-    block = block.select_columns(kurt_ids).with_column_labels(column_labels)
+    block, _ = block.reduce_general(
+        kurt_exprs, grouping_column_ids, column_labels=column_labels
+    )
     if not grouping_column_ids:
         # When ungrouped, transpose result row into a series
         # perform transpose last, so as to not invalidate cache
@@ -713,38 +688,30 @@ def kurt(


 def _mean_delta_to_power(
-    block: blocks.Block,
     n_power: int,
-    column_ids: typing.Sequence[str],
-    grouping_column_ids: typing.Sequence[str],
-) -> typing.Tuple[blocks.Block, typing.Sequence[str]]:
+    val_id: str,
+) -> ex.Expression:
     """Calculate (x-mean(x))^n. Useful for calculating moment statistics such as skew and kurtosis."""
-    window = windows.unbound(grouping_keys=tuple(grouping_column_ids))
-    block, mean_ids = block.multi_apply_window_op(column_ids, agg_ops.mean_op, window)
-    delta_ids = []
-    for val_id, mean_val_id in zip(column_ids, mean_ids):
-        delta = ops.sub_op.as_expr(val_id, mean_val_id)
-        delta_power = ops.pow_op.as_expr(delta, ex.const(n_power))
-        block, delta_power_id = block.project_expr(delta_power)
-        delta_ids.append(delta_power_id)
-    return block, delta_ids
+    mean_expr = agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref(val_id))
+    delta = ops.sub_op.as_expr(val_id, mean_expr)
+    return ops.pow_op.as_expr(delta, ex.const(n_power))


 def _skew_from_moments_and_count(
-    block: blocks.Block, count_id: str, moment3_id: str, moment2_id: str
-) -> typing.Tuple[blocks.Block, str]:
+    count: ex.Expression, moment3: ex.Expression, moment2: ex.Expression
+) -> ex.Expression:
     # Calculate skew using count, third moment and population variance
     # See G1 estimator:
     # https://en.wikipedia.org/wiki/Skewness#Sample_skewness
     moments_estimator = ops.div_op.as_expr(
-        moment3_id, ops.pow_op.as_expr(moment2_id, ex.const(3 / 2))
+        moment3, ops.pow_op.as_expr(moment2, ex.const(3 / 2))
     )

-    countminus1 = ops.sub_op.as_expr(count_id, ex.const(1))
-    countminus2 = ops.sub_op.as_expr(count_id, ex.const(2))
+    countminus1 = ops.sub_op.as_expr(count, ex.const(1))
+    countminus2 = ops.sub_op.as_expr(count, ex.const(2))
     adjustment = ops.div_op.as_expr(
         ops.unsafe_pow_op.as_expr(
-            ops.mul_op.as_expr(count_id, countminus1), ex.const(1 / 2)
+            ops.mul_op.as_expr(count, countminus1), ex.const(1 / 2)
         ),
         countminus2,
     )
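
For reference, the G1 estimator these expressions encode, checked numerically against pandas (a standalone sketch; the sample data is arbitrary):

    import numpy as np
    import pandas as pd

    x = np.array([1.0, 2.0, 4.0, 8.0, 16.0])
    n = x.size
    m2 = ((x - x.mean()) ** 2).mean()  # population variance (PopVarOp)
    m3 = ((x - x.mean()) ** 3).mean()  # third central moment (mean of delta^3)
    g1 = (m3 / m2 ** (3 / 2)) * np.sqrt(n * (n - 1)) / (n - 2)
    assert np.isclose(g1, pd.Series(x).skew())
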
@@ -753,14 +720,14 @@ def _skew_from_moments_and_count(

     # Need to produce NA if have less than 3 data points
     cleaned_skew = ops.where_op.as_expr(
-        skew, ops.ge_op.as_expr(count_id, ex.const(3)), ex.const(None)
+        skew, ops.ge_op.as_expr(count, ex.const(3)), ex.const(None)
     )
-    return block.project_expr(cleaned_skew)
+    return cleaned_skew


 def _kurt_from_moments_and_count(
-    block: blocks.Block, count_id: str, moment4_id: str, moment2_id: str
-) -> typing.Tuple[blocks.Block, str]:
+    count: ex.Expression, moment4: ex.Expression, moment2: ex.Expression
+) -> ex.Expression:
     # Kurtosis is often defined as the second standardized moment: moment(4)/moment(2)**2
     # Pandas however uses Fisher's estimator, implemented below
     # numerator = (count + 1) * (count - 1) * moment4
@@ -769,28 +736,26 @@ def _kurt_from_moments_and_count(
     # kurtosis = (numerator / denominator) - adjustment

     numerator = ops.mul_op.as_expr(
-        moment4_id,
+        moment4,
         ops.mul_op.as_expr(
-            ops.sub_op.as_expr(count_id, ex.const(1)),
-            ops.add_op.as_expr(count_id, ex.const(1)),
+            ops.sub_op.as_expr(count, ex.const(1)),
+            ops.add_op.as_expr(count, ex.const(1)),
         ),
     )

     # Denominator
-    countminus2 = ops.sub_op.as_expr(count_id, ex.const(2))
-    countminus3 = ops.sub_op.as_expr(count_id, ex.const(3))
+    countminus2 = ops.sub_op.as_expr(count, ex.const(2))
+    countminus3 = ops.sub_op.as_expr(count, ex.const(3))

     # Denominator
     denominator = ops.mul_op.as_expr(
-        ops.unsafe_pow_op.as_expr(moment2_id, ex.const(2)),
+        ops.unsafe_pow_op.as_expr(moment2, ex.const(2)),
         ops.mul_op.as_expr(countminus2, countminus3),
     )

     # Adjustment
     adj_num = ops.mul_op.as_expr(
-        ops.unsafe_pow_op.as_expr(
-            ops.sub_op.as_expr(count_id, ex.const(1)), ex.const(2)
-        ),
+        ops.unsafe_pow_op.as_expr(ops.sub_op.as_expr(count, ex.const(1)), ex.const(2)),
         ex.const(3),
     )
     adj_denom = ops.mul_op.as_expr(countminus2, countminus3)
@@ -801,9 +766,9 @@ def _kurt_from_moments_and_count(

     # Need to produce NA if have less than 4 data points
     cleaned_kurt = ops.where_op.as_expr(
-        kurt, ops.ge_op.as_expr(count_id, ex.const(4)), ex.const(None)
+        kurt, ops.ge_op.as_expr(count, ex.const(4)), ex.const(None)
     )
-    return block.project_expr(cleaned_kurt)
+    return cleaned_kurt


 def align(
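
Likewise, Fisher's kurtosis estimator spelled out in the comments above, checked against pandas (standalone sketch, arbitrary data; needs n >= 4):

    import numpy as np
    import pandas as pd

    x = np.array([1.0, 2.0, 4.0, 8.0, 16.0])
    n = x.size
    m2 = ((x - x.mean()) ** 2).mean()  # population variance
    m4 = ((x - x.mean()) ** 4).mean()  # fourth central moment
    numerator = (n + 1) * (n - 1) * m4
    denominator = (n - 2) * (n - 3) * m2 ** 2
    adjustment = 3 * (n - 1) ** 2 / ((n - 2) * (n - 3))
    assert np.isclose(numerator / denominator - adjustment, pd.Series(x).kurt())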

bigframes/core/blocks.py

Lines changed: 52 additions & 0 deletions
@@ -1167,6 +1167,58 @@ def project_block_exprs(
             index_labels=self._index_labels,
         )

+    # This is a new experimental version of aggregate that supports mixing analytic and scalar expressions
+    def reduce_general(
+        self,
+        aggregations: typing.Sequence[ex.Expression] = (),
+        by_column_ids: typing.Sequence[str] = (),
+        column_labels: Optional[pd.Index] = None,
+        *,
+        dropna: bool = True,
+    ) -> typing.Tuple[Block, typing.Sequence[str]]:
+        """
+        Apply aggregations to the block.
+
+        Arguments:
+            by_column_ids: column ids of the aggregation keys; these are preserved through the transform and used as the index.
+            aggregations: expressions to reduce, which may mix aggregation and scalar operations
+            dropna: whether null keys should be dropped
+
+        Returns:
+            Tuple[Block, Sequence[str]]:
+                The first element is the grouped block. The second is the
+                column IDs corresponding to each applied aggregation.
+        """
+        if column_labels is None:
+            column_labels = pd.Index(range(len(aggregations)))
+
+        result_expr, output_col_ids = self.expr.compute_general_reduction(
+            aggregations, by_column_ids, dropna=dropna
+        )
+
+        names: typing.List[Label] = []
+        if len(by_column_ids) == 0:
+            result_expr, label_id = result_expr.create_constant(0, pd.Int64Dtype())
+            index_columns = (label_id,)
+            names = [None]
+        else:
+            index_columns = tuple(by_column_ids)  # type: ignore
+            for by_col_id in by_column_ids:
+                if by_col_id in self.value_columns:
+                    names.append(self.col_id_to_label[by_col_id])
+                else:
+                    names.append(self.col_id_to_index_name[by_col_id])
+
+        return (
+            Block(
+                result_expr,
+                index_columns=index_columns,
+                column_labels=column_labels,
+                index_labels=names,
+            ),
+            [id.name for id in output_col_ids],
+        )
+
     def apply_window_op(
         self,
         column: str,
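
A subtlety of reduce_general is the ungrouped path: with no by_column_ids it fabricates a constant 0 (Int64) index for the single result row, which is why skew()/kurt() transpose afterward rather than here. A minimal usage sketch (hypothetical column id "x"; assumes the aggregation expression types are imported as in block_transforms.py):

    count_x = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref("x"))
    block, out_ids = block.reduce_general([count_x])  # by_column_ids defaults to ()
    # -> a one-row Block indexed by the constant 0, one column per expression;
    #    callers then transpose that row into a Series.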

bigframes/core/expression_factoring.py

Lines changed: 4 additions & 1 deletion
@@ -64,8 +64,11 @@ def plan_general_aggregation(
     plan = nodes.ProjectionNode(
         plan, tuple((cdef.expression, cdef.id) for cdef in post_scalar_exprs)
     )
+    final_ids = itertools.chain(
+        (ref.id for ref in grouping_keys), (cdef.id for cdef in post_scalar_exprs)
+    )
     plan = nodes.SelectionNode(
-        plan, tuple(nodes.AliasedRef.identity(cdef.id) for cdef in post_scalar_exprs)
+        plan, tuple(nodes.AliasedRef.identity(ident) for ident in final_ids)
     )
     return plan

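The added final_ids chain fixes the final SelectionNode to keep the grouping-key columns: previously only the aggregation outputs were selected, dropping the keys that callers such as reduce_general need as the index. A trivial standalone illustration of the resulting column order (stand-in strings for the actual ids):

    import itertools

    grouping_key_ids = ["key_a", "key_b"]  # stand-ins for ref.id values
    post_scalar_ids = ["agg_out_1", "agg_out_2"]  # stand-ins for cdef.id values

    final_ids = itertools.chain(grouping_key_ids, post_scalar_ids)
    print(list(final_ids))  # ['key_a', 'key_b', 'agg_out_1', 'agg_out_2']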