Skip to content

Commit c9a9921

Browse files
fix many issues
1 parent 4abc26c commit c9a9921

File tree

2 files changed

+44
-18
lines changed

2 files changed

+44
-18
lines changed

bigframes/core/array_value.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def compute_general_reduction(
295295
new_root = expression_factoring.plan_general_aggregation(
296296
plan, named_exprs, grouping_keys=[ex.deref(by) for by in by_column_ids]
297297
)
298+
new_root.validate_tree()
298299
target_ids = tuple(named_expr.id for named_expr in named_exprs)
299300
return (ArrayValue(new_root), target_ids)
300301

bigframes/core/expression_factoring.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
import itertools
55
from typing import (
66
cast,
7+
Dict,
78
Generic,
89
Hashable,
910
Iterable,
1011
Iterator,
1112
Mapping,
1213
Optional,
1314
Sequence,
15+
Set,
1416
Tuple,
1517
TypeVar,
1618
)
@@ -58,6 +60,7 @@ def plan_general_aggregation(
5860
tuple((cdef.expression, cdef.id) for cdef in all_aggs), # type: ignore
5961
by_column_ids=tuple(grouping_keys),
6062
)
63+
6164
post_scalar_exprs = tuple(
6265
(factored_agg.root_scalar_expr for factored_agg in factored_aggs)
6366
)
@@ -70,6 +73,7 @@ def plan_general_aggregation(
7073
plan = nodes.SelectionNode(
7174
plan, tuple(nodes.AliasedRef.identity(ident) for ident in final_ids)
7275
)
76+
7377
return plan
7478

7579

@@ -120,9 +124,7 @@ def factor_aggregation(root: nodes.ColumnDef) -> FactoredAggregation:
120124
3. A final post-aggregate scalar expression
121125
"""
122126
final_aggs = set(find_final_aggregations(root.expression))
123-
agg_inputs = set(
124-
itertools.chain.from_iterable(map(find_final_aggregations, final_aggs))
125-
)
127+
agg_inputs = set(itertools.chain.from_iterable(map(find_agg_inputs, final_aggs)))
126128

127129
agg_input_defs = tuple(
128130
nodes.ColumnDef(expr, identifiers.ColumnId.unique()) for expr in agg_inputs
@@ -131,18 +133,18 @@ def factor_aggregation(root: nodes.ColumnDef) -> FactoredAggregation:
131133
cdef.expression: expression.DerefOp(cdef.id) for cdef in agg_input_defs
132134
}
133135

136+
agg_expr_to_ids = {expr: identifiers.ColumnId.unique() for expr in final_aggs}
137+
134138
isolated_aggs = tuple(
135-
nodes.ColumnDef(
136-
sub_expressions(expr, agg_inputs_dict), identifiers.ColumnId.unique()
137-
)
138-
for expr in agg_inputs
139+
nodes.ColumnDef(sub_expressions(expr, agg_inputs_dict), agg_expr_to_ids[expr])
140+
for expr in final_aggs
139141
)
140142
agg_outputs_dict = {
141-
cdef.expression: expression.DerefOp(cdef.id) for cdef in isolated_aggs
143+
expr: expression.DerefOp(id) for expr, id in agg_expr_to_ids.items()
142144
}
143145

144146
root_scalar_expr = nodes.ColumnDef(
145-
sub_expressions(root.expression, agg_outputs_dict), root.id
147+
sub_expressions(root.expression, agg_outputs_dict), root.id # type: ignore
146148
)
147149

148150
return FactoredAggregation(
@@ -221,17 +223,23 @@ def replace_children(
221223

222224

223225
class DiGraph(Generic[T]):
224-
def __init__(self, edges: Iterable[Tuple[T, T]]):
225-
self._parents = collections.defaultdict(set)
226-
self._children = collections.defaultdict(set) # specifically, unpushed ones
226+
def __init__(self, nodes: Iterable[T], edges: Iterable[Tuple[T, T]]):
227+
self._parents: Dict[T, Set[T]] = collections.defaultdict(set)
228+
self._children: Dict[T, Set[T]] = collections.defaultdict(
229+
set
230+
) # specifically, unpushed ones
227231
# use dict for stable ordering, which grants determinism
228232
self._sinks: dict[T, None] = dict()
233+
for node in nodes:
234+
self._children[node]
235+
self._parents[node]
236+
self._sinks[node] = None
229237
for src, dst in edges:
238+
assert src in self.nodes
239+
assert dst in self.nodes
230240
self._children[src].add(dst)
231241
self._parents[dst].add(src)
232242
# sinks have no children
233-
if not self._children[dst]:
234-
self._sinks[dst] = None
235243
if src in self._sinks:
236244
del self._sinks[src]
237245

@@ -249,9 +257,11 @@ def empty(self):
249257
return len(self.nodes) == 0
250258

251259
def parents(self, node: T) -> set[T]:
260+
assert node in self._parents
252261
return self._parents[node]
253262

254263
def children(self, node: T) -> set[T]:
264+
assert node in self._children
255265
return self._children[node]
256266

257267
def remove_node(self, node: T) -> None:
@@ -276,10 +286,13 @@ def push_into_tree(
276286
by_id = {expr.id: expr for expr in exprs}
277287
# id -> id
278288
graph = DiGraph(
279-
(expr.id, child_id)
280-
for expr in exprs
281-
for child_id in expr.expression.column_references
282-
if child_id in by_id.keys()
289+
(expr.id for expr in exprs),
290+
(
291+
(expr.id, child_id)
292+
for expr in exprs
293+
for child_id in expr.expression.column_references
294+
if child_id in by_id.keys()
295+
),
283296
)
284297
# TODO: Also prevent inlining expensive or non-deterministic
285298
# We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
@@ -354,13 +367,21 @@ def graph_extract_window_expr() -> Optional[
354367

355368
return None
356369

370+
must_be_pushed = set(target_ids) - set(graph.nodes)
371+
if not must_be_pushed.issubset(curr_root.ids):
372+
missing = must_be_pushed - set(curr_root.ids)
373+
raise ValueError(f"hmmm, missing {missing}")
374+
357375
while not graph.empty:
358376
pre_size = len(graph.nodes)
359377
scalar_exprs = graph_extract_scalar_exprs()
360378
if scalar_exprs:
361379
curr_root = nodes.ProjectionNode(
362380
curr_root, tuple((x.expression, x.id) for x in scalar_exprs)
363381
)
382+
must_be_pushed = set(target_ids) - set(graph.nodes)
383+
if not must_be_pushed.issubset(curr_root.ids):
384+
raise ValueError("hmmm")
364385
while result := graph_extract_window_expr():
365386
defs, window = result
366387
assert len(defs) > 0
@@ -369,6 +390,10 @@ def graph_extract_window_expr() -> Optional[
369390
tuple(defs),
370391
window,
371392
)
393+
must_be_pushed = set(target_ids) - set(graph.nodes)
394+
if not must_be_pushed.issubset(curr_root.ids):
395+
missing = must_be_pushed - set(curr_root.ids)
396+
raise ValueError(f"hmmm, missing {missing}")
372397
if len(graph.nodes) >= pre_size:
373398
raise ValueError("graph didn't shrink")
374399
# TODO: Try to get the ordering right earlier, so can avoid this extra node.

0 commit comments

Comments
 (0)