Skip to content

Commit 27bb639

Browse files
committed
restucture et debug handle
Pull Request resolved: #7197 This diff formats the debug handle generation process in et stack by. extracting bfs graph tracing process. ghstack-source-id: 257377901 @exported-using-ghexport Differential Revision: [D66622890](https://our.internmc.facebook.com/intern/diff/D66622890/)
1 parent de74961 commit 27bb639

File tree

2 files changed

+44
-37
lines changed

2 files changed

+44
-37
lines changed

exir/graph_module.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from types import FunctionType as function
10-
from typing import Dict, List, Tuple, Union
10+
from typing import Any, Callable, Dict, List, Tuple, Union
1111

1212
import torch
1313

@@ -68,3 +68,25 @@ def get_control_flow_submodules(
6868
control_flow_submodules.append(_get_submodule(graph_module, node, 0))
6969

7070
return control_flow_submodules
71+
72+
# TODO(gasoonjia): remove this and leverage core pytorch bfs_trace_with_node_process after code freeze
73+
def bfs_trace_with_node_process(
74+
gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None]
75+
) -> None:
76+
"""Traverse the graph module and apply node_op to each node."""
77+
78+
assert isinstance(
79+
gm, torch.fx.GraphModule
80+
), f"Expected GraphModule, got {type(gm)}"
81+
82+
queue = [gm]
83+
while queue:
84+
current_graph_module = queue.pop(0)
85+
for node in current_graph_module.graph.nodes:
86+
node_op(node)
87+
88+
control_flow_submodules = [
89+
submodule
90+
for _, submodule, _ in get_control_flow_submodules(current_graph_module)
91+
]
92+
queue.extend(control_flow_submodules)

exir/passes/debug_handle_generator_pass.py

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,27 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from executorch.exir.graph_module import get_control_flow_submodules
7+
from executorch.exir.graph_module import bfs_trace_with_node_process
88
from executorch.exir.pass_base import ExportPass
99
from torch.export import ExportedProgram
1010
from torch.fx import GraphModule
1111
from torch.fx.passes.infra.pass_base import PassResult
1212

13-
1413
class DebugHandleGeneratorPass(ExportPass):
1514
def call(self, graph_module: GraphModule) -> PassResult:
1615
"""Lower a quantized reference model (with reference quantized operator patterns)
1716
to executorch backend, that has a canonical set of quantized operators
1817
"""
1918

20-
queue = [graph_module]
2119
index = 1
22-
# bfs to traverse all modules including control flow submodules to attached debug handle id
23-
while queue:
24-
current_graph_module = queue.pop(0)
25-
for node in current_graph_module.graph.nodes:
26-
node.meta["debug_handle"] = index
27-
index += 1
28-
control_flow_submodules = [
29-
submodule
30-
for _, submodule, _ in get_control_flow_submodules(current_graph_module)
31-
]
32-
queue.extend(control_flow_submodules)
20+
21+
def _extract_debug_handles_from_node(node):
22+
nonlocal index
23+
node.meta["debug_handle"] = index
24+
index += 1
25+
26+
bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node)
27+
3328
return PassResult(graph_module, True)
3429

3530

@@ -38,28 +33,18 @@ def generate_missing_debug_handles(ep: ExportedProgram):
3833
This pass is used to generate missing debug handles for the graph module and its submodules.
3934
"""
4035

41-
def get_control_flow_submodules_list(graph_module):
42-
return [
43-
submodule for _, submodule, _ in get_control_flow_submodules(graph_module)
44-
]
45-
4636
max_handle = 0
47-
queue = [ep.graph_module]
4837

49-
while queue:
50-
current_graph_module = queue.pop(0)
51-
for node in current_graph_module.graph.nodes:
52-
if "debug_handle" in node.meta:
53-
max_handle = max(max_handle, node.meta["debug_handle"])
54-
control_flow_submodules = get_control_flow_submodules_list(current_graph_module)
55-
queue.extend(control_flow_submodules)
38+
def _extract_max_debug_handle(node):
39+
nonlocal max_handle
40+
if "debug_handle" in node.meta:
41+
max_handle = max(max_handle, node.meta["debug_handle"])
42+
43+
def _insert_new_debug_handles(node):
44+
nonlocal max_handle
45+
if node.meta.get("debug_handle", 0) in (0, None):
46+
node.meta["debug_handle"] = max_handle + 1
47+
max_handle += 1
5648

57-
queue = [ep.graph_module]
58-
while queue:
59-
current_graph_module = queue.pop(0)
60-
for node in current_graph_module.graph.nodes:
61-
if node.meta.get("debug_handle", 0) in (0, None):
62-
node.meta["debug_handle"] = max_handle + 1
63-
max_handle += 1
64-
control_flow_submodules = get_control_flow_submodules_list(current_graph_module)
65-
queue.extend(control_flow_submodules)
49+
bfs_trace_with_node_process(ep.graph_module, _extract_max_debug_handle)
50+
bfs_trace_with_node_process(ep.graph_module, _insert_new_debug_handles)

0 commit comments

Comments
 (0)