44# This source code is licensed under the BSD-style license found in the 
55# LICENSE file in the root directory of this source tree. 
66
7- from  executorch .exir .graph_module  import  get_control_flow_submodules 
7+ from  executorch .exir .graph_module  import  bfs_trace_with_node_process 
88from  executorch .exir .pass_base  import  ExportPass 
99from  torch .export  import  ExportedProgram 
1010from  torch .fx  import  GraphModule 
@@ -17,19 +17,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
1717        to executorch backend, that has a canonical set of quantized operators 
1818        """ 
1919
20-         queue  =  [graph_module ]
2120        index  =  1 
22-         # bfs to traverse all modules including control flow submodules to attached debug handle id 
23-         while  queue :
24-             current_graph_module  =  queue .pop (0 )
25-             for  node  in  current_graph_module .graph .nodes :
26-                 node .meta ["debug_handle" ] =  index 
27-                 index  +=  1 
28-             control_flow_submodules  =  [
29-                 submodule 
30-                 for  _ , submodule , _  in  get_control_flow_submodules (current_graph_module )
31-             ]
32-             queue .extend (control_flow_submodules )
21+ 
22+         def  _extract_debug_handles_from_node (node ):
23+             nonlocal  index 
24+             node .meta ["debug_handle" ] =  index 
25+             index  +=  1 
26+ 
27+         bfs_trace_with_node_process (graph_module , _extract_debug_handles_from_node )
28+ 
3329        return  PassResult (graph_module , True )
3430
3531
@@ -38,28 +34,18 @@ def generate_missing_debug_handles(ep: ExportedProgram):
3834    This pass is used to generate missing debug handles for the graph module and its submodules. 
3935    """ 
4036
41-     def  get_control_flow_submodules_list (graph_module ):
42-         return  [
43-             submodule  for  _ , submodule , _  in  get_control_flow_submodules (graph_module )
44-         ]
45- 
4637    max_handle  =  0 
47-     queue  =  [ep .graph_module ]
4838
49-     while  queue :
50-         current_graph_module  =  queue .pop (0 )
51-         for  node  in  current_graph_module .graph .nodes :
52-             if  "debug_handle"  in  node .meta :
53-                 max_handle  =  max (max_handle , node .meta ["debug_handle" ])
54-         control_flow_submodules  =  get_control_flow_submodules_list (current_graph_module )
55-         queue .extend (control_flow_submodules )
39+     def  _extract_max_debug_handle (node ):
40+         nonlocal  max_handle 
41+         if  "debug_handle"  in  node .meta :
42+             max_handle  =  max (max_handle , node .meta ["debug_handle" ])
43+ 
44+     def  _insert_new_debug_handles (node ):
45+         nonlocal  max_handle 
46+         if  node .meta .get ("debug_handle" , 0 ) in  (0 , None ):
47+             node .meta ["debug_handle" ] =  max_handle  +  1 
48+             max_handle  +=  1 
5649
57-     queue  =  [ep .graph_module ]
58-     while  queue :
59-         current_graph_module  =  queue .pop (0 )
60-         for  node  in  current_graph_module .graph .nodes :
61-             if  node .meta .get ("debug_handle" , 0 ) in  (0 , None ):
62-                 node .meta ["debug_handle" ] =  max_handle  +  1 
63-                 max_handle  +=  1 
64-         control_flow_submodules  =  get_control_flow_submodules_list (current_graph_module )
65-         queue .extend (control_flow_submodules )
50+     bfs_trace_with_node_process (ep .graph_module , _extract_max_debug_handle )
51+     bfs_trace_with_node_process (ep .graph_module , _insert_new_debug_handles )
0 commit comments