20
20
import functools
21
21
import itertools
22
22
import typing
23
- from typing import (
24
- Callable ,
25
- Dict ,
26
- Generator ,
27
- Iterable ,
28
- Mapping ,
29
- Sequence ,
30
- Set ,
31
- Tuple ,
32
- Union ,
33
- )
23
+ from typing import Callable , Dict , Generator , Iterable , Mapping , Sequence , Tuple , Union
34
24
35
25
from bigframes .core import expression , field , identifiers
36
26
import bigframes .core .schema as schemata
@@ -309,33 +299,31 @@ def unique_nodes(
309
299
seen .add (item )
310
300
stack .extend (item .child_nodes )
311
301
312
- def edges (
302
+ def iter_nodes_topo (
313
303
self : BigFrameNode ,
314
- ) -> Generator [Tuple [BigFrameNode , BigFrameNode ], None , None ]:
315
- for item in self .unique_nodes ():
316
- for child in item .child_nodes :
317
- yield (item , child )
318
-
319
- def iter_nodes_topo (self : BigFrameNode ) -> Generator [BigFrameNode , None , None ]:
320
- """Returns nodes from bottom up."""
321
- queue = collections .deque (
322
- [node for node in self .unique_nodes () if not node .child_nodes ]
323
- )
324
-
304
+ ) -> Generator [Tuple [BigFrameNode , Sequence [BigFrameNode ]], None , None ]:
305
+ """Returns nodes in reverse topological order, using Kahn's algorithm."""
325
306
child_to_parents : Dict [
326
- BigFrameNode , Set [BigFrameNode ]
327
- ] = collections .defaultdict (set )
328
- for parent , child in self .edges ():
329
- child_to_parents [child ].add (parent )
330
-
331
- yielded = set ()
307
+ BigFrameNode , list [BigFrameNode ]
308
+ ] = collections .defaultdict (list )
309
+ out_degree : Dict [BigFrameNode , int ] = collections .defaultdict (int )
310
+
311
+ queue : collections .deque ["BigFrameNode" ] = collections .deque ()
312
+ for node in list (self .unique_nodes ()):
313
+ num_children = len (node .child_nodes )
314
+ out_degree [node ] = num_children
315
+ if num_children == 0 :
316
+ queue .append (node )
317
+ for child in node .child_nodes :
318
+ child_to_parents [child ].append (node )
332
319
333
320
while queue :
334
321
item = queue .popleft ()
335
- yield item
336
- yielded .add (item )
337
- for parent in child_to_parents [item ]:
338
- if set (parent .child_nodes ).issubset (yielded ):
322
+ parents = child_to_parents .get (item , [])
323
+ yield item , parents
324
+ for parent in parents :
325
+ out_degree [parent ] -= 1
326
+ if out_degree [parent ] == 0 :
339
327
queue .append (parent )
340
328
341
329
def top_down (
@@ -376,7 +364,7 @@ def bottom_up(
376
364
Returns the transformed root node.
377
365
"""
378
366
results : dict [BigFrameNode , BigFrameNode ] = {}
379
- for node in list (self .iter_nodes_topo ()):
367
+ for node , _ in list (self .iter_nodes_topo ()):
380
368
# child nodes have already been transformed
381
369
result = node .transform_children (lambda x : results [x ])
382
370
result = transform (result )
@@ -387,7 +375,7 @@ def bottom_up(
387
375
def reduce_up (self , reduction : Callable [[BigFrameNode , Tuple [T , ...]], T ]) -> T :
388
376
"""Apply a bottom-up reduction to the tree."""
389
377
results : dict [BigFrameNode , T ] = {}
390
- for node in list (self .iter_nodes_topo ()):
378
+ for node , _ in list (self .iter_nodes_topo ()):
391
379
# child nodes have already been transformed
392
380
child_results = tuple (results [child ] for child in node .child_nodes )
393
381
result = reduction (node , child_results )
0 commit comments