44# This source code is licensed under the BSD-style license found in the 
55# LICENSE file in the root directory of this source tree. 
66
7- from  executorch .exir .graph_module  import  get_control_flow_submodules 
7+ from  executorch .exir .graph_module  import  bfs_trace_with_node_process 
88from  executorch .exir .pass_base  import  ExportPass 
99from  torch .export  import  ExportedProgram 
1010from  torch .fx  import  GraphModule 
1111from  torch .fx .passes .infra .pass_base  import  PassResult 
1212
13- 
1413class  DebugHandleGeneratorPass (ExportPass ):
1514    def  call (self , graph_module : GraphModule ) ->  PassResult :
1615        """Lower a quantized reference model (with reference quantized operator patterns) 
1716        to executorch backend, that has a canonical set of quantized operators 
1817        """ 
1918
20-         queue  =  [graph_module ]
2119        index  =  1 
22-         # bfs to traverse all modules including control flow submodules to attached debug handle id 
23-         while  queue :
24-             current_graph_module  =  queue .pop (0 )
25-             for  node  in  current_graph_module .graph .nodes :
26-                 node .meta ["debug_handle" ] =  index 
27-                 index  +=  1 
28-             control_flow_submodules  =  [
29-                 submodule 
30-                 for  _ , submodule , _  in  get_control_flow_submodules (current_graph_module )
31-             ]
32-             queue .extend (control_flow_submodules )
20+ 
21+         def  _extract_debug_handles_from_node (node ):
22+             nonlocal  index 
23+             node .meta ["debug_handle" ] =  index 
24+             index  +=  1 
25+ 
26+         bfs_trace_with_node_process (graph_module , _extract_debug_handles_from_node )
27+ 
3328        return  PassResult (graph_module , True )
3429
3530
@@ -38,28 +33,18 @@ def generate_missing_debug_handles(ep: ExportedProgram):
3833    This pass is used to generate missing debug handles for the graph module and its submodules. 
3934    """ 
4035
41-     def  get_control_flow_submodules_list (graph_module ):
42-         return  [
43-             submodule  for  _ , submodule , _  in  get_control_flow_submodules (graph_module )
44-         ]
45- 
4636    max_handle  =  0 
47-     queue  =  [ep .graph_module ]
4837
49-     while  queue :
50-         current_graph_module  =  queue .pop (0 )
51-         for  node  in  current_graph_module .graph .nodes :
52-             if  "debug_handle"  in  node .meta :
53-                 max_handle  =  max (max_handle , node .meta ["debug_handle" ])
54-         control_flow_submodules  =  get_control_flow_submodules_list (current_graph_module )
55-         queue .extend (control_flow_submodules )
38+     def  _extract_max_debug_handle (node ):
39+         nonlocal  max_handle 
40+         if  "debug_handle"  in  node .meta :
41+             max_handle  =  max (max_handle , node .meta ["debug_handle" ])
42+ 
43+     def  _insert_new_debug_handles (node ):
44+         nonlocal  max_handle 
45+         if  node .meta .get ("debug_handle" , 0 ) in  (0 , None ):
46+             node .meta ["debug_handle" ] =  max_handle  +  1 
47+             max_handle  +=  1 
5648
57-     queue  =  [ep .graph_module ]
58-     while  queue :
59-         current_graph_module  =  queue .pop (0 )
60-         for  node  in  current_graph_module .graph .nodes :
61-             if  node .meta .get ("debug_handle" , 0 ) in  (0 , None ):
62-                 node .meta ["debug_handle" ] =  max_handle  +  1 
63-                 max_handle  +=  1 
64-         control_flow_submodules  =  get_control_flow_submodules_list (current_graph_module )
65-         queue .extend (control_flow_submodules )
49+     bfs_trace_with_node_process (ep .graph_module , _extract_max_debug_handle )
50+     bfs_trace_with_node_process (ep .graph_module , _insert_new_debug_handles )
0 commit comments