diff --git a/bigframes/core/bigframe_node.py b/bigframes/core/bigframe_node.py index 9054ab9ba0..0c6f56f35a 100644 --- a/bigframes/core/bigframe_node.py +++ b/bigframes/core/bigframe_node.py @@ -20,17 +20,7 @@ import functools import itertools import typing -from typing import ( - Callable, - Dict, - Generator, - Iterable, - Mapping, - Sequence, - Set, - Tuple, - Union, -) +from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple, Union from bigframes.core import expression, field, identifiers import bigframes.core.schema as schemata @@ -309,33 +299,31 @@ def unique_nodes( seen.add(item) stack.extend(item.child_nodes) - def edges( + def iter_nodes_topo( self: BigFrameNode, - ) -> Generator[Tuple[BigFrameNode, BigFrameNode], None, None]: - for item in self.unique_nodes(): - for child in item.child_nodes: - yield (item, child) - - def iter_nodes_topo(self: BigFrameNode) -> Generator[BigFrameNode, None, None]: - """Returns nodes from bottom up.""" - queue = collections.deque( - [node for node in self.unique_nodes() if not node.child_nodes] - ) - + ) -> Generator[BigFrameNode, None, None]: + """Returns nodes in reverse topological order, using Kahn's algorithm.""" child_to_parents: Dict[ - BigFrameNode, Set[BigFrameNode] - ] = collections.defaultdict(set) - for parent, child in self.edges(): - child_to_parents[child].add(parent) - - yielded = set() + BigFrameNode, list[BigFrameNode] + ] = collections.defaultdict(list) + out_degree: Dict[BigFrameNode, int] = collections.defaultdict(int) + + queue: collections.deque["BigFrameNode"] = collections.deque() + for node in list(self.unique_nodes()): + num_children = len(node.child_nodes) + out_degree[node] = num_children + if num_children == 0: + queue.append(node) + for child in node.child_nodes: + child_to_parents[child].append(node) while queue: item = queue.popleft() yield item - yielded.add(item) - for parent in child_to_parents[item]: - if set(parent.child_nodes).issubset(yielded): + parents = child_to_parents.get(item, []) + for parent in parents: + out_degree[parent] -= 1 + if out_degree[parent] == 0: queue.append(parent) def top_down(