diff --git a/exir/graph_module.py b/exir/graph_module.py index 7a032b5290d..e26d22d8145 100644 --- a/exir/graph_module.py +++ b/exir/graph_module.py @@ -7,7 +7,7 @@ # pyre-strict from types import FunctionType as function -from typing import Dict, List, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union import torch @@ -68,3 +68,23 @@ def get_control_flow_submodules( control_flow_submodules.append(_get_submodule(graph_module, node, 0)) return control_flow_submodules + + +def bfs_trace_with_node_process( + gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None] +) -> None: + """Traverse the graph module and apply node_op to each node.""" + + assert isinstance(gm, torch.fx.GraphModule), f"Expected GraphModule, got {type(gm)}" + + queue = [gm] + while queue: + current_graph_module = queue.pop(0) + for node in current_graph_module.graph.nodes: + node_op(node) + + control_flow_submodules = [ + submodule + for _, submodule, _ in get_control_flow_submodules(current_graph_module) + ] + queue.extend(control_flow_submodules) diff --git a/exir/passes/debug_handle_generator_pass.py b/exir/passes/debug_handle_generator_pass.py index 0502c47dbb3..7de8676084b 100644 --- a/exir/passes/debug_handle_generator_pass.py +++ b/exir/passes/debug_handle_generator_pass.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from executorch.exir.graph_module import get_control_flow_submodules +from executorch.exir.graph_module import bfs_trace_with_node_process from executorch.exir.pass_base import ExportPass from torch.export import ExportedProgram from torch.fx import GraphModule @@ -17,19 +17,15 @@ def call(self, graph_module: GraphModule) -> PassResult: to executorch backend, that has a canonical set of quantized operators """ - queue = [graph_module] index = 1 - # bfs to traverse all modules including control flow submodules to attached debug handle id - while queue: - current_graph_module = queue.pop(0) - for node in current_graph_module.graph.nodes: - node.meta["debug_handle"] = index - index += 1 - control_flow_submodules = [ - submodule - for _, submodule, _ in get_control_flow_submodules(current_graph_module) - ] - queue.extend(control_flow_submodules) + + def _extract_debug_handles_from_node(node): + nonlocal index + node.meta["debug_handle"] = index + index += 1 + + bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node) + return PassResult(graph_module, True) @@ -38,28 +34,18 @@ def generate_missing_debug_handles(ep: ExportedProgram): This pass is used to generate missing debug handles for the graph module and its submodules. """ - def get_control_flow_submodules_list(graph_module): - return [ - submodule for _, submodule, _ in get_control_flow_submodules(graph_module) - ] - max_handle = 0 - queue = [ep.graph_module] - while queue: - current_graph_module = queue.pop(0) - for node in current_graph_module.graph.nodes: - if "debug_handle" in node.meta: - max_handle = max(max_handle, node.meta["debug_handle"]) - control_flow_submodules = get_control_flow_submodules_list(current_graph_module) - queue.extend(control_flow_submodules) + def _extract_max_debug_handle(node): + nonlocal max_handle + if "debug_handle" in node.meta: + max_handle = max(max_handle, node.meta["debug_handle"]) + + def _insert_new_debug_handles(node): + nonlocal max_handle + if node.meta.get("debug_handle", 0) in (0, None): + node.meta["debug_handle"] = max_handle + 1 + max_handle += 1 - queue = [ep.graph_module] - while queue: - current_graph_module = queue.pop(0) - for node in current_graph_module.graph.nodes: - if node.meta.get("debug_handle", 0) in (0, None): - node.meta["debug_handle"] = max_handle + 1 - max_handle += 1 - control_flow_submodules = get_control_flow_submodules_list(current_graph_module) - queue.extend(control_flow_submodules) + bfs_trace_with_node_process(ep.graph_module, _extract_max_debug_handle) + bfs_trace_with_node_process(ep.graph_module, _insert_new_debug_handles)