11import collections
22import dataclasses
33import functools
4- from typing import cast , Generic , Hashable , Iterable , Optional , Sequence , Tuple , TypeVar
4+ import itertools
5+ from typing import (
6+ cast ,
7+ Generic ,
8+ Hashable ,
9+ Iterable ,
10+ Iterator ,
11+ Mapping ,
12+ Optional ,
13+ Sequence ,
14+ Tuple ,
15+ TypeVar ,
16+ )
517
618from bigframes .core import agg_expressions , expression , identifiers , nodes , window_spec
719
820_MAX_INLINE_COMPLEXITY = 10
921
1022
23+ def plan_general_col_exprs (
24+ plan : nodes .BigFrameNode , col_exprs : Sequence [nodes .ColumnDef ]
25+ ) -> nodes .BigFrameNode :
26+ # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
27+ target_ids = tuple (named_expr .id for named_expr in col_exprs )
28+
29+ fragments = tuple (
30+ itertools .chain .from_iterable (
31+ fragmentize_expression (expr ) for expr in col_exprs
32+ )
33+ )
34+ return push_into_tree (plan , fragments , target_ids )
35+
36+
37+ def plan_general_aggregation (
38+ plan : nodes .BigFrameNode ,
39+ agg_defs : Sequence [nodes .ColumnDef ],
40+ grouping_keys : Sequence [expression .DerefOp ],
41+ ) -> nodes .BigFrameNode :
42+ factored_aggs = [factor_aggregation (agg_def ) for agg_def in agg_defs ]
43+ all_inputs = list (
44+ itertools .chain (* (factored_agg .agg_inputs for factored_agg in factored_aggs ))
45+ )
46+ # TODO: Windowize
47+ window_def = window_spec .WindowSpec (grouping_keys = tuple (grouping_keys ))
48+ windowized_inputs = [
49+ nodes .ColumnDef (windowize (cdef .expression , window_def ), cdef .id )
50+ for cdef in all_inputs
51+ ]
52+ plan = plan_general_col_exprs (plan , windowized_inputs )
53+ all_aggs = list (
54+ itertools .chain (* (factored_agg .agg_exprs for factored_agg in factored_aggs ))
55+ )
56+ plan = nodes .AggregateNode (
57+ plan ,
58+ tuple ((cdef .expression , cdef .id ) for cdef in all_aggs ), # type: ignore
59+ by_column_ids = tuple (grouping_keys ),
60+ )
61+ post_scalar_exprs = tuple (
62+ (factored_agg .root_scalar_expr for factored_agg in factored_aggs )
63+ )
64+ plan = nodes .ProjectionNode (
65+ plan , tuple ((cdef .expression , cdef .id ) for cdef in post_scalar_exprs )
66+ )
67+ plan = nodes .SelectionNode (
68+ plan , tuple (nodes .AliasedRef .identity (cdef .id ) for cdef in post_scalar_exprs )
69+ )
70+ return plan
71+
72+
1173@dataclasses .dataclass (frozen = True , eq = False )
1274class FactoredExpression :
1375 root_expr : expression .Expression
@@ -24,6 +86,101 @@ def fragmentize_expression(root: nodes.ColumnDef) -> Sequence[nodes.ColumnDef]:
2486 return (root_expr , * factored_expr .sub_exprs )
2587
2688
89+ @dataclasses .dataclass (frozen = True , eq = False )
90+ class FactoredAggregation :
91+ # pure scalar expression
92+ root_scalar_expr : nodes .ColumnDef
93+ # pure agg expression, only refs cols and consts
94+ agg_exprs : Tuple [nodes .ColumnDef , ...]
95+ # can be analytic, scalar op, const, col refs
96+ agg_inputs : Tuple [nodes .ColumnDef , ...]
97+
98+
99+ def windowize (
100+ root : expression .Expression , window : window_spec .WindowSpec
101+ ) -> expression .Expression :
102+ def windowize_local (expr : expression .Expression ):
103+ if isinstance (expr , agg_expressions .Aggregation ):
104+ return agg_expressions .WindowExpression (expr , window )
105+ if isinstance (expr , agg_expressions .WindowExpression ):
106+ raise ValueError (f"Expression { expr } already windowed!" )
107+ return expr
108+
109+ return root .bottom_up (windowize_local )
110+
111+
112+ def factor_aggregation (root : nodes .ColumnDef ) -> FactoredAggregation :
113+ """
114+ Factor an aggregation def into three components.
115+ 1. Input column expressions (includes analytic expressions)
116+ 2. The set of underlying primitive aggregations
117+ 3. A final post-aggregate scalar expression
118+ """
119+ final_aggs = set (find_final_aggregations (root .expression ))
120+ agg_inputs = set (
121+ itertools .chain .from_iterable (map (find_final_aggregations , final_aggs ))
122+ )
123+
124+ agg_input_defs = tuple (
125+ nodes .ColumnDef (expr , identifiers .ColumnId .unique ()) for expr in agg_inputs
126+ )
127+ agg_inputs_dict = {
128+ cdef .expression : expression .DerefOp (cdef .id ) for cdef in agg_input_defs
129+ }
130+
131+ isolated_aggs = tuple (
132+ nodes .ColumnDef (
133+ sub_expressions (expr , agg_inputs_dict ), identifiers .ColumnId .unique ()
134+ )
135+ for expr in agg_inputs
136+ )
137+ agg_outputs_dict = {
138+ cdef .expression : expression .DerefOp (cdef .id ) for cdef in isolated_aggs
139+ }
140+
141+ root_scalar_expr = nodes .ColumnDef (
142+ sub_expressions (root .expression , agg_outputs_dict ), root .id
143+ )
144+
145+ return FactoredAggregation (
146+ root_scalar_expr = root_scalar_expr ,
147+ agg_exprs = isolated_aggs ,
148+ agg_inputs = agg_input_defs ,
149+ )
150+
151+
152+ def sub_expressions (
153+ root : expression .Expression ,
154+ replacements : Mapping [expression .Expression , expression .Expression ],
155+ ) -> expression .Expression :
156+ return root .top_down (lambda x : replacements .get (x , x ))
157+
158+
159+ def find_final_aggregations (
160+ root : expression .Expression ,
161+ ) -> Iterator [agg_expressions .Aggregation ]:
162+ if isinstance (root , agg_expressions .Aggregation ):
163+ yield root
164+ elif isinstance (root , expression .OpExpression ):
165+ for child in root .children :
166+ yield from find_final_aggregations (child )
167+ elif isinstance (root , expression .ScalarConstantExpression ):
168+ return
169+ else :
170+ # eg, window expression, column references not allowed
171+ raise ValueError (f"Unexpected node: { root } " )
172+
173+
174+ def find_agg_inputs (
175+ root : agg_expressions .Aggregation ,
176+ ) -> Iterator [expression .Expression ]:
177+ for child in root .children :
178+ if not isinstance (
179+ child , (expression .DerefOp , expression .ScalarConstantExpression )
180+ ):
181+ yield child
182+
183+
27184def gather_fragments (
28185 root : expression .Expression , fragmentized_children : Sequence [FactoredExpression ]
29186) -> FactoredExpression :
0 commit comments