Skip to content

Commit 2d07bb3

Browse files
perf: Fragmentize multiple col expressions as one (#2333)
1 parent 3c54c68 commit 2d07bb3

File tree

3 files changed

+111
-92
lines changed

3 files changed

+111
-92
lines changed

bigframes/core/block_transforms.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -625,21 +625,7 @@ def skew(
625625
# counts, moment3 for each column
626626
aggregations = []
627627
for col in original_columns:
628-
delta3_expr = _mean_delta_to_power(3, col)
629-
count_agg = agg_expressions.UnaryAggregation(
630-
agg_ops.count_op,
631-
ex.deref(col),
632-
)
633-
moment3_agg = agg_expressions.UnaryAggregation(
634-
agg_ops.mean_op,
635-
delta3_expr,
636-
)
637-
variance_agg = agg_expressions.UnaryAggregation(
638-
agg_ops.PopVarOp(),
639-
ex.deref(col),
640-
)
641-
skew_expr = _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
642-
aggregations.append(skew_expr)
628+
aggregations.append(skew_expr(ex.deref(col)))
643629

644630
block = block.aggregate(
645631
aggregations, grouping_column_ids, column_labels=column_labels
@@ -663,16 +649,7 @@ def kurt(
663649
# counts, moment4 for each column
664650
kurt_exprs = []
665651
for col in original_columns:
666-
delta_4_expr = _mean_delta_to_power(4, col)
667-
count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref(col))
668-
moment4_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, delta_4_expr)
669-
variance_agg = agg_expressions.UnaryAggregation(
670-
agg_ops.PopVarOp(), ex.deref(col)
671-
)
672-
673-
# Corresponds to order of aggregations in preceding loop
674-
kurt_expr = _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
675-
kurt_exprs.append(kurt_expr)
652+
kurt_exprs.append(kurt_expr(ex.deref(col)))
676653

677654
block = block.aggregate(
678655
kurt_exprs, grouping_column_ids, column_labels=column_labels
@@ -686,13 +663,38 @@ def kurt(
686663
return block
687664

688665

666+
def skew_expr(expr: ex.Expression) -> ex.Expression:
667+
delta3_expr = _mean_delta_to_power(3, expr)
668+
count_agg = agg_expressions.UnaryAggregation(
669+
agg_ops.count_op,
670+
expr,
671+
)
672+
moment3_agg = agg_expressions.UnaryAggregation(
673+
agg_ops.mean_op,
674+
delta3_expr,
675+
)
676+
variance_agg = agg_expressions.UnaryAggregation(
677+
agg_ops.PopVarOp(),
678+
expr,
679+
)
680+
return _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
681+
682+
683+
def kurt_expr(expr: ex.Expression) -> ex.Expression:
684+
delta_4_expr = _mean_delta_to_power(4, expr)
685+
count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, expr)
686+
moment4_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, delta_4_expr)
687+
variance_agg = agg_expressions.UnaryAggregation(agg_ops.PopVarOp(), expr)
688+
return _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
689+
690+
689691
def _mean_delta_to_power(
690692
n_power: int,
691-
val_id: str,
693+
col_expr: ex.Expression,
692694
) -> ex.Expression:
693695
"""Calculate (x-mean(x))^n. Useful for calculating moment statistics such as skew and kurtosis."""
694-
mean_expr = agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref(val_id))
695-
delta = ops.sub_op.as_expr(val_id, mean_expr)
696+
mean_expr = agg_expressions.UnaryAggregation(agg_ops.mean_op, col_expr)
697+
delta = ops.sub_op.as_expr(col_expr, mean_expr)
696698
return ops.pow_op.as_expr(delta, ex.const(n_power))
697699

698700

bigframes/core/expression.py

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
from __future__ import annotations
1616

1717
import abc
18-
import collections
1918
import dataclasses
2019
import functools
2120
import itertools
2221
import typing
23-
from typing import Callable, Dict, Generator, Mapping, Tuple, TypeVar, Union
22+
from typing import Callable, Generator, Mapping, TypeVar, Union
2423

2524
import pandas as pd
2625

@@ -162,57 +161,6 @@ def walk(self) -> Generator[Expression, None, None]:
162161
for child in self.children:
163162
yield from child.children
164163

165-
def unique_nodes(
166-
self: Expression,
167-
) -> Generator[Expression, None, None]:
168-
"""Walks the tree for unique nodes"""
169-
seen = set()
170-
stack: list[Expression] = [self]
171-
while stack:
172-
item = stack.pop()
173-
if item not in seen:
174-
yield item
175-
seen.add(item)
176-
stack.extend(item.children)
177-
178-
def iter_nodes_topo(
179-
self: Expression,
180-
) -> Generator[Expression, None, None]:
181-
"""Returns nodes in reverse topological order, using Kahn's algorithm."""
182-
child_to_parents: Dict[Expression, list[Expression]] = collections.defaultdict(
183-
list
184-
)
185-
out_degree: Dict[Expression, int] = collections.defaultdict(int)
186-
187-
queue: collections.deque["Expression"] = collections.deque()
188-
for node in list(self.unique_nodes()):
189-
num_children = len(node.children)
190-
out_degree[node] = num_children
191-
if num_children == 0:
192-
queue.append(node)
193-
for child in node.children:
194-
child_to_parents[child].append(node)
195-
196-
while queue:
197-
item = queue.popleft()
198-
yield item
199-
parents = child_to_parents.get(item, [])
200-
for parent in parents:
201-
out_degree[parent] -= 1
202-
if out_degree[parent] == 0:
203-
queue.append(parent)
204-
205-
def reduce_up(self, reduction: Callable[[Expression, Tuple[T, ...]], T]) -> T:
206-
"""Apply a bottom-up reduction to the tree."""
207-
results: dict[Expression, T] = {}
208-
for node in list(self.iter_nodes_topo()):
209-
# child nodes have already been transformed
210-
child_results = tuple(results[child] for child in node.children)
211-
result = reduction(node, child_results)
212-
results[node] = result
213-
214-
return results[self]
215-
216164

217165
@dataclasses.dataclass(frozen=True)
218166
class ScalarConstantExpression(Expression):

bigframes/core/expression_factoring.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
import functools
1919
import itertools
2020
from typing import (
21+
Callable,
2122
cast,
23+
Dict,
24+
Generator,
2225
Hashable,
2326
Iterable,
2427
Iterator,
@@ -40,18 +43,72 @@
4043

4144
_MAX_INLINE_COMPLEXITY = 10
4245

46+
T = TypeVar("T")
47+
48+
49+
def unique_nodes(
50+
roots: Sequence[expression.Expression],
51+
) -> Generator[expression.Expression, None, None]:
52+
"""Walks the tree for unique nodes"""
53+
seen = set()
54+
stack: list[expression.Expression] = list(roots)
55+
while stack:
56+
item = stack.pop()
57+
if item not in seen:
58+
yield item
59+
seen.add(item)
60+
stack.extend(item.children)
61+
62+
63+
def iter_nodes_topo(
64+
roots: Sequence[expression.Expression],
65+
) -> Generator[expression.Expression, None, None]:
66+
"""Returns nodes in reverse topological order, using Kahn's algorithm."""
67+
child_to_parents: Dict[
68+
expression.Expression, list[expression.Expression]
69+
] = collections.defaultdict(list)
70+
out_degree: Dict[expression.Expression, int] = collections.defaultdict(int)
71+
72+
queue: collections.deque[expression.Expression] = collections.deque()
73+
for node in unique_nodes(roots):
74+
num_children = len(node.children)
75+
out_degree[node] = num_children
76+
if num_children == 0:
77+
queue.append(node)
78+
for child in node.children:
79+
child_to_parents[child].append(node)
80+
81+
while queue:
82+
item = queue.popleft()
83+
yield item
84+
parents = child_to_parents.get(item, [])
85+
for parent in parents:
86+
out_degree[parent] -= 1
87+
if out_degree[parent] == 0:
88+
queue.append(parent)
89+
90+
91+
def reduce_up(
92+
roots: Sequence[expression.Expression],
93+
reduction: Callable[[expression.Expression, Tuple[T, ...]], T],
94+
) -> Tuple[T, ...]:
95+
"""Apply a bottom-up reduction to the forest."""
96+
results: dict[expression.Expression, T] = {}
97+
for node in list(iter_nodes_topo(roots)):
98+
# child nodes have already been transformed
99+
child_results = tuple(results[child] for child in node.children)
100+
result = reduction(node, child_results)
101+
results[node] = result
102+
103+
return tuple(results[root] for root in roots)
104+
43105

44106
def apply_col_exprs_to_plan(
45107
plan: nodes.BigFrameNode, col_exprs: Sequence[nodes.ColumnDef]
46108
) -> nodes.BigFrameNode:
47-
# TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
48109
target_ids = tuple(named_expr.id for named_expr in col_exprs)
49110

50-
fragments = tuple(
51-
itertools.chain.from_iterable(
52-
fragmentize_expression(expr) for expr in col_exprs
53-
)
54-
)
111+
fragments = fragmentize_expression(col_exprs)
55112
return push_into_tree(plan, fragments, target_ids)
56113

57114

@@ -101,14 +158,26 @@ class FactoredExpression:
101158
sub_exprs: Tuple[nodes.ColumnDef, ...]
102159

103160

104-
def fragmentize_expression(root: nodes.ColumnDef) -> Sequence[nodes.ColumnDef]:
161+
def fragmentize_expression(
162+
roots: Sequence[nodes.ColumnDef],
163+
) -> Sequence[nodes.ColumnDef]:
105164
"""
106165
The goal of this function is to factor out an expression into multiple sub-expressions.
107166
"""
108-
109-
factored_expr = root.expression.reduce_up(gather_fragments)
110-
root_expr = nodes.ColumnDef(factored_expr.root_expr, root.id)
111-
return (root_expr, *factored_expr.sub_exprs)
167+
# TODO: Fragmentize a bit less aggressively
168+
factored_exprs = reduce_up([root.expression for root in roots], gather_fragments)
169+
root_exprs = (
170+
nodes.ColumnDef(factored.root_expr, root.id)
171+
for factored, root in zip(factored_exprs, roots)
172+
)
173+
return (
174+
*root_exprs,
175+
*dedupe(
176+
itertools.chain.from_iterable(
177+
factored_expr.sub_exprs for factored_expr in factored_exprs
178+
)
179+
),
180+
)
112181

113182

114183
@dataclasses.dataclass(frozen=True, eq=False)

0 commit comments

Comments
 (0)