feat: Auto-plan complex reduction expressions #2298
@@ -16,7 +16,6 @@
 from dataclasses import dataclass
 import datetime
 import functools
-import itertools
 import typing
 from typing import Iterable, List, Mapping, Optional, Sequence, Tuple
@@ -271,15 +270,32 @@ def compute_general_expression(self, assignments: Sequence[ex.Expression]):
         named_exprs = [
             nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
         ]
         # TODO: Push this to rewrite later to go from block expression to planning form
-        # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
-        fragments = tuple(
-            itertools.chain.from_iterable(
-                expression_factoring.fragmentize_expression(expr)
-                for expr in named_exprs
-            )
-        )
+        new_root = expression_factoring.plan_general_col_exprs(self.node, named_exprs)
         target_ids = tuple(named_expr.id for named_expr in named_exprs)
-        new_root = expression_factoring.push_into_tree(self.node, fragments, target_ids)
         return (ArrayValue(new_root), target_ids)

+    def compute_general_reduction(
+        self,
+        assignments: Sequence[ex.Expression],
+        by_column_ids: typing.Sequence[str] = (),
+        *,
+        dropna: bool = False,
+    ):
+        # Warning: this function does not check if the expression is a valid reduction, and may fail spectacularly on invalid inputs
+        plan = self.node
+        if dropna:
+            for col_id in by_column_ids:
+                plan = nodes.FilterNode(plan, ops.notnull_op.as_expr(col_id))
+
+        named_exprs = [
+            nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
+        ]
+        # TODO: Push this to rewrite later to go from block expression to planning form
+        new_root = expression_factoring.plan_general_aggregation(
+            plan, named_exprs, grouping_keys=[ex.deref(by) for by in by_column_ids]
+        )
+        target_ids = tuple(named_expr.id for named_expr in named_exprs)
+        return (ArrayValue(new_root), target_ids)
+
     def project_to_id(self, expression: ex.Expression):
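A note on the new `compute_general_reduction` entry point: it accepts expressions that freely mix aggregations with scalar operators and defers all placement and planning to `expression_factoring.plan_general_aggregation`. A minimal usage sketch follows, reusing the `ex` / `ops` / `agg_ops` / `agg_expressions` aliases already used in this file (imports omitted; the ratio expression and column names are hypothetical, not taken from this PR):

```python
# Hypothetical "complex reduction": a scalar op applied over two aggregates.
mean_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref("col_a"))
count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref("col_a"))
mean_per_count = ops.div_op.as_expr(mean_agg, count_agg)

# array_value is an ArrayValue. Per the warning in the diff, nothing here
# validates that the input is a legal reduction; e.g. a pure scalar
# expression with no aggregate would be accepted and fail later in planning.
new_value, out_ids = array_value.compute_general_reduction(
    [mean_per_count], by_column_ids=["group_col"], dropna=True
)
```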
@@ -623,40 +623,28 @@ def skew(
     original_columns = skew_column_ids
     column_labels = block.select_columns(original_columns).column_labels

-    block, delta3_ids = _mean_delta_to_power(
-        block, 3, original_columns, grouping_column_ids
-    )
-    # counts, moment3 for each column
     aggregations = []
-    for i, col in enumerate(original_columns):
+    for col in original_columns:
+        delta3_expr = _mean_delta_to_power(3, col)
         count_agg = agg_expressions.UnaryAggregation(
             agg_ops.count_op,
             ex.deref(col),
         )
         moment3_agg = agg_expressions.UnaryAggregation(
             agg_ops.mean_op,
-            ex.deref(delta3_ids[i]),
+            delta3_expr,
         )
         variance_agg = agg_expressions.UnaryAggregation(
             agg_ops.PopVarOp(),
             ex.deref(col),
         )
-        aggregations.extend([count_agg, moment3_agg, variance_agg])
+        skew_expr = _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
+        aggregations.append(skew_expr)

-    block, agg_ids = block.aggregate(
-        by_column_ids=grouping_column_ids, aggregations=aggregations
+    block, _ = block.reduce_general(
+        aggregations, grouping_column_ids, column_labels=column_labels
     )

-    skew_ids = []
-    for i, col in enumerate(original_columns):
-        # Corresponds to order of aggregations in preceding loop
-        count_id, moment3_id, var_id = agg_ids[i * 3 : (i * 3) + 3]
-        block, skew_id = _skew_from_moments_and_count(
-            block, count_id, moment3_id, var_id
-        )
-        skew_ids.append(skew_id)
-
-    block = block.select_columns(skew_ids).with_column_labels(column_labels)
     if not grouping_column_ids:
         # When ungrouped, transpose result row into a series
         # perform transpose last, so as to not invalidate cache
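With this change, each column's skew compiles into a single deferred expression over three aggregates: `count`, the third central moment (`mean((x - mean(x))**3)` via `_mean_delta_to_power`), and population variance, plus a NULL guard for fewer than 3 points. As a standalone sanity check (not PR code), the G1 estimator those pieces implement can be reproduced against pandas directly:

```python
import numpy as np
import pandas as pd

x = pd.Series([1.0, 2.0, 4.0, 8.0, 16.0])  # any sample with n >= 3
n = x.count()
moment3 = ((x - x.mean()) ** 3).mean()  # third central moment
var_pop = x.var(ddof=0)                 # population variance (PopVarOp)

# G1 estimator: (moment3 / var**(3/2)) * sqrt(n * (n - 1)) / (n - 2)
g1 = (moment3 / var_pop ** (3 / 2)) * np.sqrt(n * (n - 1)) / (n - 2)
assert np.isclose(g1, x.skew())  # pandas computes the same estimator
```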
@@ -673,36 +661,23 @@ def kurt(
 ) -> blocks.Block:
     original_columns = skew_column_ids
     column_labels = block.select_columns(original_columns).column_labels

-    block, delta4_ids = _mean_delta_to_power(
-        block, 4, original_columns, grouping_column_ids
-    )
-    # counts, moment4 for each column
-    aggregations = []
-    for i, col in enumerate(original_columns):
+    kurt_exprs = []
+    for col in original_columns:
+        delta_4_expr = _mean_delta_to_power(4, col)
         count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref(col))
-        moment4_agg = agg_expressions.UnaryAggregation(
-            agg_ops.mean_op, ex.deref(delta4_ids[i])
-        )
+        moment4_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, delta_4_expr)
         variance_agg = agg_expressions.UnaryAggregation(
             agg_ops.PopVarOp(), ex.deref(col)
         )
-        aggregations.extend([count_agg, moment4_agg, variance_agg])
-
-    block, agg_ids = block.aggregate(
-        by_column_ids=grouping_column_ids, aggregations=aggregations
-    )
-
-    kurt_ids = []
-    for i, col in enumerate(original_columns):
-        # Corresponds to order of aggregations in preceding loop
-        count_id, moment4_id, var_id = agg_ids[i * 3 : (i * 3) + 3]
-        block, kurt_id = _kurt_from_moments_and_count(
-            block, count_id, moment4_id, var_id
-        )
-        kurt_ids.append(kurt_id)
+        kurt_expr = _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
+        kurt_exprs.append(kurt_expr)

-    block = block.select_columns(kurt_ids).with_column_labels(column_labels)
+    block, _ = block.reduce_general(
+        kurt_exprs, grouping_column_ids, column_labels=column_labels
+    )
     if not grouping_column_ids:
         # When ungrouped, transpose result row into a series
         # perform transpose last, so as to not invalidate cache
@@ -713,38 +688,30 @@ def kurt(


 def _mean_delta_to_power(
-    block: blocks.Block,
     n_power: int,
-    column_ids: typing.Sequence[str],
-    grouping_column_ids: typing.Sequence[str],
-) -> typing.Tuple[blocks.Block, typing.Sequence[str]]:
+    val_id: str,
+) -> ex.Expression:
     """Calculate (x-mean(x))^n. Useful for calculating moment statistics such as skew and kurtosis."""
-    window = windows.unbound(grouping_keys=tuple(grouping_column_ids))
-    block, mean_ids = block.multi_apply_window_op(column_ids, agg_ops.mean_op, window)
-    delta_ids = []
-    for val_id, mean_val_id in zip(column_ids, mean_ids):
-        delta = ops.sub_op.as_expr(val_id, mean_val_id)
-        delta_power = ops.pow_op.as_expr(delta, ex.const(n_power))
-        block, delta_power_id = block.project_expr(delta_power)
-        delta_ids.append(delta_power_id)
-    return block, delta_ids
+    mean_expr = agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref(val_id))
+    delta = ops.sub_op.as_expr(val_id, mean_expr)
+    return ops.pow_op.as_expr(delta, ex.const(n_power))


 def _skew_from_moments_and_count(
-    block: blocks.Block, count_id: str, moment3_id: str, moment2_id: str
-) -> typing.Tuple[blocks.Block, str]:
+    count: ex.Expression, moment3: ex.Expression, moment2: ex.Expression
+) -> ex.Expression:
     # Calculate skew using count, third moment and population variance
     # See G1 estimator:
     # https://en.wikipedia.org/wiki/Skewness#Sample_skewness
     moments_estimator = ops.div_op.as_expr(
-        moment3_id, ops.pow_op.as_expr(moment2_id, ex.const(3 / 2))
+        moment3, ops.pow_op.as_expr(moment2, ex.const(3 / 2))
    )

-    countminus1 = ops.sub_op.as_expr(count_id, ex.const(1))
-    countminus2 = ops.sub_op.as_expr(count_id, ex.const(2))
+    countminus1 = ops.sub_op.as_expr(count, ex.const(1))
+    countminus2 = ops.sub_op.as_expr(count, ex.const(2))
     adjustment = ops.div_op.as_expr(
         ops.unsafe_pow_op.as_expr(
-            ops.mul_op.as_expr(count_id, countminus1), ex.const(1 / 2)
+            ops.mul_op.as_expr(count, countminus1), ex.const(1 / 2)
         ),
         countminus2,
     )
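The shape change in `_mean_delta_to_power` is the crux of the refactor: instead of materializing per-row mean columns with `multi_apply_window_op` and projecting delta columns onto the block, it now returns a pure expression with the `mean` aggregate embedded inline, leaving placement to the planner. An illustrative sketch, using the file's own helper (nothing here is from the PR beyond the call itself):

```python
# Builds a deferred tree equivalent to: pow(sub(x, mean(x)), 3).
# No block mutation and no intermediate column ids; the embedded mean()
# aggregate is planned later by reduce_general / compute_general_reduction.
delta3 = _mean_delta_to_power(3, "x")  # an ex.Expression; nothing is computed yet
```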
@@ -753,14 +720,14 @@ def _skew_from_moments_and_count(

     # Need to produce NA if have less than 3 data points
     cleaned_skew = ops.where_op.as_expr(
-        skew, ops.ge_op.as_expr(count_id, ex.const(3)), ex.const(None)
+        skew, ops.ge_op.as_expr(count, ex.const(3)), ex.const(None)
     )
-    return block.project_expr(cleaned_skew)
+    return cleaned_skew


 def _kurt_from_moments_and_count(
-    block: blocks.Block, count_id: str, moment4_id: str, moment2_id: str
-) -> typing.Tuple[blocks.Block, str]:
+    count: ex.Expression, moment4: ex.Expression, moment2: ex.Expression
+) -> ex.Expression:
     # Kurtosis is often defined as the second standardize moment: moment(4)/moment(2)**2
     # Pandas however uses Fisher’s estimator, implemented below
     # numerator = (count + 1) * (count - 1) * moment4
@@ -769,28 +736,26 @@ def _kurt_from_moments_and_count(
     # kurtosis = (numerator / denominator) - adjustment

     numerator = ops.mul_op.as_expr(
-        moment4_id,
+        moment4,
         ops.mul_op.as_expr(
-            ops.sub_op.as_expr(count_id, ex.const(1)),
-            ops.add_op.as_expr(count_id, ex.const(1)),
+            ops.sub_op.as_expr(count, ex.const(1)),
+            ops.add_op.as_expr(count, ex.const(1)),
         ),
     )

     # Denominator
-    countminus2 = ops.sub_op.as_expr(count_id, ex.const(2))
-    countminus3 = ops.sub_op.as_expr(count_id, ex.const(3))
+    countminus2 = ops.sub_op.as_expr(count, ex.const(2))
+    countminus3 = ops.sub_op.as_expr(count, ex.const(3))

     # Denominator
     denominator = ops.mul_op.as_expr(
-        ops.unsafe_pow_op.as_expr(moment2_id, ex.const(2)),
+        ops.unsafe_pow_op.as_expr(moment2, ex.const(2)),
         ops.mul_op.as_expr(countminus2, countminus3),
     )

     # Adjustment
     adj_num = ops.mul_op.as_expr(
-        ops.unsafe_pow_op.as_expr(
-            ops.sub_op.as_expr(count_id, ex.const(1)), ex.const(2)
-        ),
+        ops.unsafe_pow_op.as_expr(ops.sub_op.as_expr(count, ex.const(1)), ex.const(2)),
         ex.const(3),
     )
     adj_denom = ops.mul_op.as_expr(countminus2, countminus3)
@@ -801,9 +766,9 @@ def _kurt_from_moments_and_count(

     # Need to produce NA if have less than 4 data points
     cleaned_kurt = ops.where_op.as_expr(
-        kurt, ops.ge_op.as_expr(count_id, ex.const(4)), ex.const(None)
+        kurt, ops.ge_op.as_expr(count, ex.const(4)), ex.const(None)
     )
-    return block.project_expr(cleaned_kurt)
+    return cleaned_kurt


 def align(
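For reference, the Fisher (G2) estimator spelled out in the comments above can be reproduced and checked against pandas directly (standalone snippet, not PR code; the result is defined for n >= 4, matching the NULL guard):

```python
import numpy as np
import pandas as pd

x = pd.Series([1.0, 2.0, 4.0, 8.0, 16.0])  # any sample with n >= 4
n = x.count()
moment4 = ((x - x.mean()) ** 4).mean()  # fourth central moment
var_pop = x.var(ddof=0)                 # population variance

numerator = moment4 * (n - 1) * (n + 1)
denominator = var_pop**2 * (n - 2) * (n - 3)
adjustment = 3 * (n - 1) ** 2 / ((n - 2) * (n - 3))
kurt = numerator / denominator - adjustment
assert np.isclose(kurt, x.kurt())  # pandas uses the same Fisher estimator
```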
@@ -1167,6 +1167,45 @@ def project_block_exprs(
             index_labels=self._index_labels,
         )

+    # This is a new experimental version of aggregate that supports mixing analytic and scalar expressions
+    def reduce_general(
+        self,
+        aggregations: typing.Sequence[ex.Expression] = (),
+        by_column_ids: typing.Sequence[str] = (),
+        column_labels: Optional[pd.Index] = None,
+        *,
+        dropna: bool = True,
+    ) -> typing.Tuple[Block, typing.Sequence[str]]:
+        if column_labels is None:
+            column_labels = pd.Index(range(len(aggregations)))
+
+        result_expr, output_col_ids = self.expr.compute_general_reduction(
+            aggregations, by_column_ids, dropna=dropna
+        )
+
+        names: typing.List[Label] = []
+        if len(by_column_ids) == 0:
+            result_expr, label_id = result_expr.create_constant(0, pd.Int64Dtype())
+            index_columns = (label_id,)
+            names = [None]
+        else:
+            index_columns = tuple(by_column_ids)  # type: ignore
+            for by_col_id in by_column_ids:
+                if by_col_id in self.value_columns:
+                    names.append(self.col_id_to_label[by_col_id])
+                else:
+                    names.append(self.col_id_to_index_name[by_col_id])
+
+        return (
+            Block(
+                result_expr,
+                index_columns=index_columns,
+                column_labels=column_labels,
+                index_labels=names,
+            ),
+            [id.name for id in output_col_ids],
+        )
+
     def apply_window_op(
         self,
         column: str,
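Usage-wise, `reduce_general` mirrors the existing `aggregate` but accepts arbitrary reduction expressions, including scalar ops over aggregates; the updated `skew` and `kurt` above are its first callers. A sketch of the ungrouped path, reusing the aliases from elsewhere in this diff (the column name and label are hypothetical):

```python
# Ungrouped reduction: with no by_column_ids, reduce_general synthesizes a
# constant 0 index, so the result is a single-row block.
agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref("col_a"))
block, out_ids = block.reduce_general([agg], column_labels=pd.Index(["col_a_mean"]))
```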