44import itertools
55from typing import (
66 cast ,
7+ Dict ,
78 Generic ,
89 Hashable ,
910 Iterable ,
1011 Iterator ,
1112 Mapping ,
1213 Optional ,
1314 Sequence ,
15+ Set ,
1416 Tuple ,
1517 TypeVar ,
1618)
@@ -58,6 +60,7 @@ def plan_general_aggregation(
5860 tuple ((cdef .expression , cdef .id ) for cdef in all_aggs ), # type: ignore
5961 by_column_ids = tuple (grouping_keys ),
6062 )
63+
6164 post_scalar_exprs = tuple (
6265 (factored_agg .root_scalar_expr for factored_agg in factored_aggs )
6366 )
@@ -70,6 +73,7 @@ def plan_general_aggregation(
7073 plan = nodes .SelectionNode (
7174 plan , tuple (nodes .AliasedRef .identity (ident ) for ident in final_ids )
7275 )
76+
7377 return plan
7478
7579
@@ -120,9 +124,7 @@ def factor_aggregation(root: nodes.ColumnDef) -> FactoredAggregation:
120124 3. A final post-aggregate scalar expression
121125 """
122126 final_aggs = set (find_final_aggregations (root .expression ))
123- agg_inputs = set (
124- itertools .chain .from_iterable (map (find_final_aggregations , final_aggs ))
125- )
127+ agg_inputs = set (itertools .chain .from_iterable (map (find_agg_inputs , final_aggs )))
126128
127129 agg_input_defs = tuple (
128130 nodes .ColumnDef (expr , identifiers .ColumnId .unique ()) for expr in agg_inputs
@@ -131,18 +133,18 @@ def factor_aggregation(root: nodes.ColumnDef) -> FactoredAggregation:
131133 cdef .expression : expression .DerefOp (cdef .id ) for cdef in agg_input_defs
132134 }
133135
136+ agg_expr_to_ids = {expr : identifiers .ColumnId .unique () for expr in final_aggs }
137+
134138 isolated_aggs = tuple (
135- nodes .ColumnDef (
136- sub_expressions (expr , agg_inputs_dict ), identifiers .ColumnId .unique ()
137- )
138- for expr in agg_inputs
139+ nodes .ColumnDef (sub_expressions (expr , agg_inputs_dict ), agg_expr_to_ids [expr ])
140+ for expr in final_aggs
139141 )
140142 agg_outputs_dict = {
141- cdef . expression : expression .DerefOp (cdef . id ) for cdef in isolated_aggs
143+ expr : expression .DerefOp (id ) for expr , id in agg_expr_to_ids . items ()
142144 }
143145
144146 root_scalar_expr = nodes .ColumnDef (
145- sub_expressions (root .expression , agg_outputs_dict ), root .id
147+ sub_expressions (root .expression , agg_outputs_dict ), root .id # type: ignore
146148 )
147149
148150 return FactoredAggregation (
@@ -221,17 +223,23 @@ def replace_children(
221223
222224
223225class DiGraph (Generic [T ]):
224- def __init__ (self , edges : Iterable [Tuple [T , T ]]):
225- self ._parents = collections .defaultdict (set )
226- self ._children = collections .defaultdict (set ) # specifically, unpushed ones
226+ def __init__ (self , nodes : Iterable [T ], edges : Iterable [Tuple [T , T ]]):
227+ self ._parents : Dict [T , Set [T ]] = collections .defaultdict (set )
228+ self ._children : Dict [T , Set [T ]] = collections .defaultdict (
229+ set
230+ ) # specifically, unpushed ones
227231 # use dict for stable ordering, which grants determinism
228232 self ._sinks : dict [T , None ] = dict ()
233+ for node in nodes :
234+ self ._children [node ]
235+ self ._parents [node ]
236+ self ._sinks [node ] = None
229237 for src , dst in edges :
238+ assert src in self .nodes
239+ assert dst in self .nodes
230240 self ._children [src ].add (dst )
231241 self ._parents [dst ].add (src )
232242 # sinks have no children
233- if not self ._children [dst ]:
234- self ._sinks [dst ] = None
235243 if src in self ._sinks :
236244 del self ._sinks [src ]
237245
@@ -249,9 +257,11 @@ def empty(self):
249257 return len (self .nodes ) == 0
250258
251259 def parents (self , node : T ) -> set [T ]:
260+ assert node in self ._parents
252261 return self ._parents [node ]
253262
254263 def children (self , node : T ) -> set [T ]:
264+ assert node in self ._children
255265 return self ._children [node ]
256266
257267 def remove_node (self , node : T ) -> None :
@@ -276,10 +286,13 @@ def push_into_tree(
276286 by_id = {expr .id : expr for expr in exprs }
277287 # id -> id
278288 graph = DiGraph (
279- (expr .id , child_id )
280- for expr in exprs
281- for child_id in expr .expression .column_references
282- if child_id in by_id .keys ()
289+ (expr .id for expr in exprs ),
290+ (
291+ (expr .id , child_id )
292+ for expr in exprs
293+ for child_id in expr .expression .column_references
294+ if child_id in by_id .keys ()
295+ ),
283296 )
284297 # TODO: Also prevent inlining expensive or non-deterministic
285298 # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
@@ -354,13 +367,21 @@ def graph_extract_window_expr() -> Optional[
354367
355368 return None
356369
370+ must_be_pushed = set (target_ids ) - set (graph .nodes )
371+ if not must_be_pushed .issubset (curr_root .ids ):
372+ missing = must_be_pushed - set (curr_root .ids )
373+ raise ValueError (f"hmmm, missing { missing } " )
374+
357375 while not graph .empty :
358376 pre_size = len (graph .nodes )
359377 scalar_exprs = graph_extract_scalar_exprs ()
360378 if scalar_exprs :
361379 curr_root = nodes .ProjectionNode (
362380 curr_root , tuple ((x .expression , x .id ) for x in scalar_exprs )
363381 )
382+ must_be_pushed = set (target_ids ) - set (graph .nodes )
383+ if not must_be_pushed .issubset (curr_root .ids ):
384+ raise ValueError ("hmmm" )
364385 while result := graph_extract_window_expr ():
365386 defs , window = result
366387 assert len (defs ) > 0
@@ -369,6 +390,10 @@ def graph_extract_window_expr() -> Optional[
369390 tuple (defs ),
370391 window ,
371392 )
393+ must_be_pushed = set (target_ids ) - set (graph .nodes )
394+ if not must_be_pushed .issubset (curr_root .ids ):
395+ missing = must_be_pushed - set (curr_root .ids )
396+ raise ValueError (f"hmmm, missing { missing } " )
372397 if len (graph .nodes ) >= pre_size :
373398 raise ValueError ("graph didn't shrink" )
374399 # TODO: Try to get the ordering right earlier, so can avoid this extra node.
0 commit comments