Skip to content

Commit cba77b8

Browse files
feat: Auto-plan complex reduction expressions
1 parent b487cf1 commit cba77b8

File tree

3 files changed

+180
-9
lines changed

3 files changed

+180
-9
lines changed

bigframes/core/array_value.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from dataclasses import dataclass
1717
import datetime
1818
import functools
19-
import itertools
2019
import typing
2120
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple
2221

@@ -271,15 +270,25 @@ def compute_general_expression(self, assignments: Sequence[ex.Expression]):
271270
nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
272271
]
273272
# TODO: Push this to rewrite later to go from block expression to planning form
274-
# TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
275-
fragments = tuple(
276-
itertools.chain.from_iterable(
277-
expression_factoring.fragmentize_expression(expr)
278-
for expr in named_exprs
279-
)
273+
new_root = expression_factoring.plan_general_col_exprs(self.node, named_exprs)
274+
275+
target_ids = tuple(named_expr.id for named_expr in named_exprs)
276+
return (ArrayValue(new_root), target_ids)
277+
278+
def compute_general_reduction(
    self,
    assignments: Sequence[ex.Expression],
    by_column_ids: typing.Sequence[str] = (),
):
    """Plan a batch of reduction expressions on top of this value's node.

    Args:
        assignments: Reduction expressions to compute; each one is given a
            fresh unique column id.
        by_column_ids: Ids of the columns to group by. Empty by default.

    Returns:
        A ``(ArrayValue, ids)`` pair: the value wrapping the newly planned
        root node, and the ids assigned to ``assignments`` in input order.
    """
    # Warning: this function does not check if the expression is a valid reduction, and may fail spectacularly on invalid inputs
    column_defs = [
        nodes.ColumnDef(assignment, ids.ColumnId.unique())
        for assignment in assignments
    ]
    group_keys = [ex.deref(column_id) for column_id in by_column_ids]

    # TODO: Push this to rewrite later to go from block expression to planning form
    planned_root = expression_factoring.plan_general_aggregation(
        self.node, column_defs, grouping_keys=group_keys
    )
    output_ids = tuple(column_def.id for column_def in column_defs)
    return (ArrayValue(planned_root), output_ids)
284293

285294
def project_to_id(self, expression: ex.Expression):

bigframes/core/expression.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ def bottom_up(self, t: Callable[[Expression], Expression]) -> Expression:
152152
expr = t(expr)
153153
return expr
154154

155+
def top_down(self, t: Callable[[Expression], Expression]) -> Expression:
    """Rewrite the tree in pre-order.

    ``t`` is applied to this node first; the children of whatever ``t``
    returned are then rewritten recursively with the same callback.
    """

    def descend(child: Expression) -> Expression:
        return child.top_down(t)

    return t(self).transform_children(descend)
159+
155160
def walk(self) -> Generator[Expression, None, None]:
156161
yield self
157162
for child in self.children:

bigframes/core/expression_factoring.py

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,75 @@
11
import collections
22
import dataclasses
33
import functools
4-
from typing import cast, Generic, Hashable, Iterable, Optional, Sequence, Tuple, TypeVar
4+
import itertools
5+
from typing import (
6+
cast,
7+
Generic,
8+
Hashable,
9+
Iterable,
10+
Iterator,
11+
Mapping,
12+
Optional,
13+
Sequence,
14+
Tuple,
15+
TypeVar,
16+
)
517

618
from bigframes.core import agg_expressions, expression, identifiers, nodes, window_spec
719

820
_MAX_INLINE_COMPLEXITY = 10
921

1022

23+
def plan_general_col_exprs(
    plan: nodes.BigFrameNode, col_exprs: Sequence[nodes.ColumnDef]
) -> nodes.BigFrameNode:
    """Fragmentize each column definition and push the fragments into ``plan``.

    Every definition is split on its own with ``fragmentize_expression``; the
    resulting fragments are concatenated in input order and handed to
    ``push_into_tree`` together with the ids of the requested output columns.
    """
    output_ids = tuple(col_def.id for col_def in col_exprs)

    # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
    pieces = []
    for col_def in col_exprs:
        pieces.extend(fragmentize_expression(col_def))

    return push_into_tree(plan, tuple(pieces), output_ids)
35+
36+
37+
def plan_general_aggregation(
    plan: nodes.BigFrameNode,
    agg_defs: Sequence[nodes.ColumnDef],
    grouping_keys: Sequence[expression.DerefOp],
) -> nodes.BigFrameNode:
    """Plan a set of aggregation definitions as a stack of nodes over ``plan``.

    Each definition is split with ``factor_aggregation`` and the pieces are
    stacked, bottom to top, as:

    1. the aggregation inputs, windowized over ``grouping_keys`` and planned
       through ``plan_general_col_exprs``;
    2. an ``AggregateNode`` computing every isolated aggregation, grouped by
       ``grouping_keys``;
    3. a ``ProjectionNode`` evaluating the post-aggregation scalar expressions;
    4. a ``SelectionNode`` keeping only those final output columns.
    """
    keys = tuple(grouping_keys)
    factored = [factor_aggregation(agg_def) for agg_def in agg_defs]

    # TODO: Windowize
    partition_window = window_spec.WindowSpec(grouping_keys=keys)
    input_defs = [
        nodes.ColumnDef(
            windowize(input_def.expression, partition_window), input_def.id
        )
        for part in factored
        for input_def in part.agg_inputs
    ]
    plan = plan_general_col_exprs(plan, input_defs)

    aggregations = tuple(
        (agg_def.expression, agg_def.id)
        for part in factored
        for agg_def in part.agg_exprs
    )
    plan = nodes.AggregateNode(
        plan,
        aggregations,  # type: ignore
        by_column_ids=keys,
    )

    final_defs = tuple(part.root_scalar_expr for part in factored)
    plan = nodes.ProjectionNode(
        plan, tuple((final.expression, final.id) for final in final_defs)
    )
    return nodes.SelectionNode(
        plan, tuple(nodes.AliasedRef.identity(final.id) for final in final_defs)
    )
71+
72+
1173
@dataclasses.dataclass(frozen=True, eq=False)
1274
class FactoredExpression:
1375
root_expr: expression.Expression
@@ -24,6 +86,101 @@ def fragmentize_expression(root: nodes.ColumnDef) -> Sequence[nodes.ColumnDef]:
2486
return (root_expr, *factored_expr.sub_exprs)
2587

2688

89+
@dataclasses.dataclass(frozen=True, eq=False)
class FactoredAggregation:
    """The three pieces an aggregation definition is split into by ``factor_aggregation``."""

    # pure scalar expression
    # (evaluated after aggregation; carries the id of the original definition)
    root_scalar_expr: nodes.ColumnDef
    # pure agg expression, only refs cols and consts
    agg_exprs: Tuple[nodes.ColumnDef, ...]
    # can be analytic, scalar op, const, col refs
    # (these are the columns that must exist before the aggregations run)
    agg_inputs: Tuple[nodes.ColumnDef, ...]
97+
98+
99+
def windowize(
    root: expression.Expression, window: window_spec.WindowSpec
) -> expression.Expression:
    """Turn every aggregation inside ``root`` into a window expression.

    Nodes are visited through ``Expression.bottom_up``. Each ``Aggregation``
    is wrapped in a ``WindowExpression`` over ``window``; every other node is
    returned untouched.

    Raises:
        ValueError: if a ``WindowExpression`` is encountered during the walk,
            i.e. some part of ``root`` was already windowed.
    """

    def _wrap(node: expression.Expression):
        if isinstance(node, agg_expressions.Aggregation):
            wrapped = agg_expressions.WindowExpression(node, window)
            return wrapped
        elif isinstance(node, agg_expressions.WindowExpression):
            raise ValueError(f"Expression {node} already windowed!")
        else:
            return node

    return root.bottom_up(_wrap)
110+
111+
112+
def factor_aggregation(root: nodes.ColumnDef) -> FactoredAggregation:
    """
    Factor an aggregation def into three components.
    1. Input column expressions (includes analytic expressions)
    2. The set of underlying primitive aggregations
    3. A final post-aggregate scalar expression

    Args:
        root: The aggregation definition to factor. Its expression must be
            built only from aggregations, scalar ops and constants.

    Returns:
        A ``FactoredAggregation`` where:
        * ``agg_inputs`` defines one fresh column per distinct non-trivial
          aggregation argument (plain column refs and constants need none);
        * ``agg_exprs`` defines one fresh column per distinct aggregation, with
          its non-trivial arguments replaced by refs to ``agg_inputs``;
        * ``root_scalar_expr`` is ``root`` with every aggregation replaced by a
          ref to the matching ``agg_exprs`` column, keeping ``root.id``.

    Raises:
        ValueError: propagated from ``find_final_aggregations`` when the
            expression contains a node that is not allowed outside of an
            aggregation.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # generated columns come out in a deterministic order (set order does not).
    final_aggs = tuple(dict.fromkeys(find_final_aggregations(root.expression)))
    agg_inputs = tuple(
        dict.fromkeys(
            itertools.chain.from_iterable(map(find_agg_inputs, final_aggs))
        )
    )

    agg_input_defs = tuple(
        nodes.ColumnDef(expr, identifiers.ColumnId.unique()) for expr in agg_inputs
    )
    agg_inputs_dict = {
        cdef.expression: expression.DerefOp(cdef.id) for cdef in agg_input_defs
    }

    # Ids are keyed on the ORIGINAL aggregation expressions: those are the
    # nodes that actually occur in root.expression, so they are what the final
    # substitution has to look up.
    agg_ids = {agg: identifiers.ColumnId.unique() for agg in final_aggs}

    # Only the direct arguments of an aggregation are swapped for input refs.
    # The aggregation node itself is never replaced, even when the same
    # aggregation also appears as an argument of another aggregation.
    isolated_aggs = tuple(
        nodes.ColumnDef(
            agg.transform_children(lambda child: agg_inputs_dict.get(child, child)),
            agg_id,
        )
        for agg, agg_id in agg_ids.items()
    )
    agg_outputs_dict = {
        agg: expression.DerefOp(agg_id) for agg, agg_id in agg_ids.items()
    }

    root_scalar_expr = nodes.ColumnDef(
        sub_expressions(root.expression, agg_outputs_dict), root.id
    )

    return FactoredAggregation(
        root_scalar_expr=root_scalar_expr,
        agg_exprs=isolated_aggs,
        agg_inputs=agg_input_defs,
    )
150+
151+
152+
def sub_expressions(
    root: expression.Expression,
    replacements: Mapping[expression.Expression, expression.Expression],
) -> expression.Expression:
    """Replace sub-expressions of ``root`` according to ``replacements``.

    The tree is rewritten with ``top_down``: a node found as a key in
    ``replacements`` is swapped for the mapped value, any other node is kept.
    """

    def _swap(node):
        return replacements.get(node, node)

    return root.top_down(_swap)
157+
158+
159+
def find_final_aggregations(
    root: expression.Expression,
) -> Iterator[agg_expressions.Aggregation]:
    """Yield the outermost aggregations of ``root``.

    The search descends through ``OpExpression`` nodes only and stops at the
    first ``Aggregation`` on each path, so aggregations nested inside another
    aggregation are not reported. Scalar constants contribute nothing.

    Raises:
        ValueError: on any other kind of node. Because this is a generator,
            the error surfaces while iterating, not when the function is called.
    """
    if isinstance(root, agg_expressions.Aggregation):
        yield root
        return

    if isinstance(root, expression.OpExpression):
        for operand in root.children:
            yield from find_final_aggregations(operand)
        return

    if isinstance(root, expression.ScalarConstantExpression):
        return

    # eg, window expression, column references not allowed
    raise ValueError(f"Unexpected node: {root}")
172+
173+
174+
def find_agg_inputs(
    root: agg_expressions.Aggregation,
) -> Iterator[expression.Expression]:
    """Yield the direct children of ``root`` that are neither a ``DerefOp`` nor a
    ``ScalarConstantExpression``."""
    trivial_inputs = (expression.DerefOp, expression.ScalarConstantExpression)
    yield from (
        argument
        for argument in root.children
        if not isinstance(argument, trivial_inputs)
    )
182+
183+
27184
def gather_fragments(
28185
root: expression.Expression, fragmentized_children: Sequence[FactoredExpression]
29186
) -> FactoredExpression:

0 commit comments

Comments
 (0)