44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from executorch .exir .graph_module import get_control_flow_submodules
7+ from executorch .exir .graph_module import bfs_trace_with_node_process
88from executorch .exir .pass_base import ExportPass
99from torch .export import ExportedProgram
1010from torch .fx import GraphModule
1111from torch .fx .passes .infra .pass_base import PassResult
1212
13-
1413class DebugHandleGeneratorPass (ExportPass ):
1514 def call (self , graph_module : GraphModule ) -> PassResult :
1615 """Lower a quantized reference model (with reference quantized operator patterns)
1716 to executorch backend, that has a canonical set of quantized operators
1817 """
1918
20- queue = [graph_module ]
2119 index = 1
22- # bfs to traverse all modules including control flow submodules to attached debug handle id
23- while queue :
24- current_graph_module = queue .pop (0 )
25- for node in current_graph_module .graph .nodes :
26- node .meta ["debug_handle" ] = index
27- index += 1
28- control_flow_submodules = [
29- submodule
30- for _ , submodule , _ in get_control_flow_submodules (current_graph_module )
31- ]
32- queue .extend (control_flow_submodules )
20+
21+ def _extract_debug_handles_from_node (node ):
22+ nonlocal index
23+ node .meta ["debug_handle" ] = index
24+ index += 1
25+
26+ bfs_trace_with_node_process (graph_module , _extract_debug_handles_from_node )
27+
3328 return PassResult (graph_module , True )
3429
3530
@@ -38,28 +33,18 @@ def generate_missing_debug_handles(ep: ExportedProgram):
3833 This pass is used to generate missing debug handles for the graph module and its submodules.
3934 """
4035
41- def get_control_flow_submodules_list (graph_module ):
42- return [
43- submodule for _ , submodule , _ in get_control_flow_submodules (graph_module )
44- ]
45-
4636 max_handle = 0
47- queue = [ep .graph_module ]
4837
49- while queue :
50- current_graph_module = queue .pop (0 )
51- for node in current_graph_module .graph .nodes :
52- if "debug_handle" in node .meta :
53- max_handle = max (max_handle , node .meta ["debug_handle" ])
54- control_flow_submodules = get_control_flow_submodules_list (current_graph_module )
55- queue .extend (control_flow_submodules )
38+ def _extract_max_debug_handle (node ):
39+ nonlocal max_handle
40+ if "debug_handle" in node .meta :
41+ max_handle = max (max_handle , node .meta ["debug_handle" ])
42+
43+ def _insert_new_debug_handles (node ):
44+ nonlocal max_handle
45+ if node .meta .get ("debug_handle" , 0 ) in (0 , None ):
46+ node .meta ["debug_handle" ] = max_handle + 1
47+ max_handle += 1
5648
57- queue = [ep .graph_module ]
58- while queue :
59- current_graph_module = queue .pop (0 )
60- for node in current_graph_module .graph .nodes :
61- if node .meta .get ("debug_handle" , 0 ) in (0 , None ):
62- node .meta ["debug_handle" ] = max_handle + 1
63- max_handle += 1
64- control_flow_submodules = get_control_flow_submodules_list (current_graph_module )
65- queue .extend (control_flow_submodules )
49+ bfs_trace_with_node_process (ep .graph_module , _extract_max_debug_handle )
50+ bfs_trace_with_node_process (ep .graph_module , _insert_new_debug_handles )
0 commit comments