diff --git a/bigframes/core/bigframe_node.py b/bigframes/core/bigframe_node.py index 9054ab9ba0..c8985e2c93 100644 --- a/bigframes/core/bigframe_node.py +++ b/bigframes/core/bigframe_node.py @@ -20,17 +20,7 @@ import functools import itertools import typing -from typing import ( - Callable, - Dict, - Generator, - Iterable, - Mapping, - Sequence, - Set, - Tuple, - Union, -) +from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple, Union from bigframes.core import expression, field, identifiers import bigframes.core.schema as schemata @@ -309,33 +299,31 @@ def unique_nodes( seen.add(item) stack.extend(item.child_nodes) - def edges( + def iter_nodes_topo( self: BigFrameNode, - ) -> Generator[Tuple[BigFrameNode, BigFrameNode], None, None]: - for item in self.unique_nodes(): - for child in item.child_nodes: - yield (item, child) - - def iter_nodes_topo(self: BigFrameNode) -> Generator[BigFrameNode, None, None]: - """Returns nodes from bottom up.""" - queue = collections.deque( - [node for node in self.unique_nodes() if not node.child_nodes] - ) - + ) -> Generator[Tuple[BigFrameNode, Sequence[BigFrameNode]], None, None]: + """Returns nodes in topological order, using Kahn's algorithm.""" child_to_parents: Dict[ - BigFrameNode, Set[BigFrameNode] - ] = collections.defaultdict(set) - for parent, child in self.edges(): - child_to_parents[child].add(parent) - - yielded = set() + BigFrameNode, list[BigFrameNode] + ] = collections.defaultdict(list) + out_degree: Dict[BigFrameNode, int] = collections.defaultdict(int) + + queue: collections.deque["BigFrameNode"] = collections.deque() + for node in list(self.unique_nodes()): + num_children = len(node.child_nodes) + out_degree[node] = num_children + if num_children == 0: + queue.append(node) + for child in node.child_nodes: + child_to_parents[child].append(node) while queue: item = queue.popleft() - yield item - yielded.add(item) - for parent in child_to_parents[item]: - if set(parent.child_nodes).issubset(yielded): + parents = child_to_parents.get(item, []) + yield item, parents + for parent in parents: + out_degree[parent] -= 1 + if out_degree[parent] == 0: queue.append(parent) def top_down( @@ -376,7 +364,7 @@ def bottom_up( Returns the transformed root node. """ results: dict[BigFrameNode, BigFrameNode] = {} - for node in list(self.iter_nodes_topo()): + for node, _ in list(self.iter_nodes_topo()): # child nodes have already been transformed result = node.transform_children(lambda x: results[x]) result = transform(result) @@ -387,7 +375,7 @@ def bottom_up( def reduce_up(self, reduction: Callable[[BigFrameNode, Tuple[T, ...]], T]) -> T: """Apply a bottom-up reduction to the tree.""" results: dict[BigFrameNode, T] = {} - for node in list(self.iter_nodes_topo()): + for node, _ in list(self.iter_nodes_topo()): # child nodes have already been transformed child_results = tuple(results[child] for child in node.child_nodes) result = reduction(node, child_results) diff --git a/tests/unit/core/compile/sqlglot/test_compile_concat.py b/tests/unit/core/compile/sqlglot/test_compile_concat.py index ec7e83a4b0..d223ad967c 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_concat.py +++ b/tests/unit/core/compile/sqlglot/test_compile_concat.py @@ -24,7 +24,6 @@ def test_compile_concat( scalars_types_pandas_df: pd.DataFrame, compiler_session: bigframes.Session, snapshot ): - # TODO: concat two same dataframes, which SQL does not get reused. # TODO: concat dataframes from a gbq table but trigger a windows compiler. df1 = bpd.DataFrame(scalars_types_pandas_df, session=compiler_session) df1 = df1[["rowindex", "int64_col", "string_col"]]