@@ -151,6 +151,9 @@ def __init__(self, origin: Node) -> None:
def create(origin: Node) -> "BackwardBfsArgIter":
    """Build an iterator for ``origin`` with the origin itself already popped.

    The iterator is constructed from ``origin``, the arguments of ``origin``
    are queued via ``add_children``, and one element is then consumed so the
    origin is not offered back to the caller as a candidate.

    Args:
        origin: The node the region starts from.

    Returns:
        The prepared ``BackwardBfsArgIter``.
    """
    it = BackwardBfsArgIter(origin)
    it.add_children(origin)
    # pop the origin node, since it is the origin of
    # the region and does not need to be considered for addition.
    # The pop must happen outside the assert: assert statements are removed
    # under ``python -O``, which would silently skip the pop.
    popped = it.next()
    assert popped is not None
    return it
155158
156159 def next (self ) -> Optional [Node ]:
@@ -165,17 +168,11 @@ def peek(self) -> Optional[Node]:
165168 return self ._cur
166169
def add_children(self, node: Node) -> None:
    """Queue the ``Node`` arguments of ``node`` on this iterator.

    The arguments come from ``_get_flat_args_unique(node, {})``; presumably
    this yields the flattened, de-duplicated args and kwargs of ``node``
    (confirm against that helper). Entries that are not ``Node`` instances
    are skipped.
    """
    for candidate in _get_flat_args_unique(node, {}):
        if not isinstance(candidate, Node):
            continue
        self._append(candidate)

179176 def _append (self , arg : Node ) -> None :
180177 if self ._cur is None :
181178 self ._cur = arg
@@ -328,6 +325,38 @@ def __str__(self) -> str:
328325 return f"GraphRegionTracker(hash_to_duplicates={ self .hash_to_duplicates } , node_to_duplicates={ self .node_to_duplicates } )"
329326
330327
class RegionWrapper:
    """Bookkeeping attached to a single region while it is being expanded.

    Keeps together the region's node list, an argument iterator started at the
    region's first node, the set of nodes currently in the region, and the
    union of the ancestor sets of those nodes.
    """

    def __init__(
        self, region: Region, node_to_recursive_ancestors: dict[Node, set[Node]]
    ) -> None:
        """Wrap a one-node ``region``.

        Args:
            region: The region to wrap; must contain exactly one node.
            node_to_recursive_ancestors: Map from each node to its ancestor set.
        """
        assert len(region) == 1, "all regions should start with one node"
        origin = region[0]
        self.region = region
        self.node_to_recursive_ancestors = node_to_recursive_ancestors
        self.iter = BackwardBfsArgIter.create(origin)
        self.nodes_unique = OrderedSet([origin])
        # Copy the mapped set so later ``update`` calls leave the map untouched.
        self.ancestors = set(node_to_recursive_ancestors[origin])

    def next_candidate(self) -> Optional[Node]:
        """Return the next node from the underlying iterator, or ``None``."""
        return self.iter.next()

    def will_inclusion_create_cycle(self, node: Node) -> bool:
        """Return True if a user of ``node`` outside the region is a tracked ancestor."""
        return any(
            user in self.ancestors
            for user in node.users
            if user not in self.nodes_unique
        )

    def add(self, node: Node) -> None:
        """Add ``node`` to the region and fold in its children and ancestors."""
        self.nodes_unique.add(node)
        self.region.append(node)
        self.iter.add_children(node)
        new_ancestors = self.node_to_recursive_ancestors[node]
        self.ancestors.update(new_ancestors)
359+
331360def fully_expand_region_group (
332361 regions : list [Region ],
333362 seen_nodes : set [Node ],
@@ -339,20 +368,12 @@ def fully_expand_region_group(
339368
340369 # All regions should start with 1 node
341370 assert all (len (region ) == 1 for region in regions )
342- region_iters = []
343- for region in regions :
344- (origin ,) = region # Only works for 1 element sets
345- region_iters .append (BackwardBfsArgIter .create (origin ))
346-
347- nodes_to_add : list [Node ] = []
348-
349- # we already have the origin node in each region
350- for region_it in region_iters :
351- node = region_it .next ()
352- assert node
353- region_it .add_children (node )
371+ region_wrappers = [
372+ RegionWrapper (region , node_to_recursive_ancestors ) for region in regions
373+ ]
354374
355- current_node = region_iters [0 ].next ()
375+ nodes_to_add = OrderedSet [Node ]()
376+ current_node = region_wrappers [0 ].next_candidate ()
356377
357378 # No children
358379 if current_node is None :
@@ -362,46 +383,51 @@ def fully_expand_region_group(
362383 # regions are only expanded if the node to add is valid
363384 # for ALL regions
364385 while current_node :
365- add_node = not _will_create_cycle (
366- current_node , regions [ 0 ], node_to_recursive_ancestors
386+ add_to_all_regions = not region_wrappers [ 0 ]. will_inclusion_create_cycle (
387+ current_node
367388 )
368389 nodes_to_add .clear ()
369- nodes_to_add .append (current_node )
370- nodes_to_add_set = set (nodes_to_add )
371- for ind , region_it in enumerate (region_iters [1 :]):
372- ind += 1 # compensate for the 0th region
373- node = region_it .next ()
390+ nodes_to_add .add (current_node )
391+ for region_wrapper in region_wrappers [1 :]:
392+ candidate = region_wrapper .next_candidate ()
374393
375394 debug_log ("--------------------" )
376- debug_log ("considering adding: %s, cur_node: %s" , node , current_node )
377- debug_log ("previously claimed nodes: %s" , node in seen_nodes )
378- if node :
379- debug_log ("is_identical: %s" , is_identical_fn (node , current_node ))
380- add_node &= (
381- node not in seen_nodes
382- and node not in nodes_to_add_set
383- and node .op != "placeholder"
384- and is_identical_fn (node , current_node )
385- and not _will_create_cycle (
386- node , regions [ind ], node_to_recursive_ancestors
387- )
388- )
389- nodes_to_add .append (node )
390- nodes_to_add_set .add (node )
391- else :
392- add_node = False
395+ debug_log (
396+ "considering candidate: %s, cur_node: %s" , candidate , current_node
397+ )
398+
399+ if not candidate or not add_to_all_regions :
400+ add_to_all_regions = False
401+ continue
402+
403+ debug_log (
404+ "candidate in previously claimed nodes?: %s" , candidate in seen_nodes
405+ )
406+ debug_log ("is_identical: %s" , is_identical_fn (candidate , current_node ))
407+
408+ add_to_all_regions &= (
409+ candidate not in seen_nodes
410+ and candidate not in nodes_to_add
411+ and candidate .op != "placeholder"
412+ and is_identical_fn (candidate , current_node )
413+ and not region_wrapper .will_inclusion_create_cycle (candidate )
414+ )
415+ nodes_to_add .add (candidate )
393416
417+ debug_log (f"add_to_all_regions: { add_to_all_regions } " )
394418 debug_log ("--------------------" )
395419
396- if add_node :
397- for region , region_it , node in zip (regions , region_iters , nodes_to_add ):
398- region .append (node )
420+ if add_to_all_regions :
421+ assert len (region_wrappers ) == len (nodes_to_add ), (
422+ "Numer of nodes to add must equal the number of regions"
423+ )
424+ for region_wrapper , node in zip (region_wrappers , nodes_to_add ):
425+ region_wrapper .add (node )
399426 debug_log ("adding %s's children" , node )
400427 debug_log ("%s %s" , node .args , list (node .kwargs .items ()))
401- region_it .add_children (node )
402428 seen_nodes .add (node )
403429
404- current_node = region_iters [0 ].next ()
430+ current_node = region_wrappers [0 ].next_candidate ()
405431
406432 # Ensure regions are sorted in topological order
407433 for region in regions :
@@ -424,20 +450,3 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[No
424450 )
425451 node_to_recursive_ancestors [node ].add (arg )
426452 return node_to_recursive_ancestors
427-
428-
429- def _will_create_cycle (
430- node_to_add : Node ,
431- region : Region ,
432- node_to_recursive_ancestors : dict [Node , set [Node ]],
433- ) -> bool :
434- region_set : set [Node ] = set (region )
435- region_ancestors : set [Node ] = set (
436- tree_flatten ([list (node_to_recursive_ancestors [node ]) for node in region ])[0 ]
437- )
438- external_users = [user for user in node_to_add .users if user not in region_set ]
439- for user in external_users :
440- if user in region_ancestors :
441- return True
442-
443- return False
0 commit comments