From 57749956e9acfb62dcb52f828383f127f5676bb9 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 9 Jul 2025 21:57:22 -0700 Subject: [PATCH] propagate debug handle from edge dialect graph back to exported graph Using the exported graph from torch.export as the source of truth for AOT intermediate outputs is our target. One thing blocking us is that the exported graph does not have debug handles, which will not show up in the export flow until DebugHandleGenerationPass, the last step of to_edge(). We need to equip the graph with the same debug handles used in the ExecuTorch flow. This diff creates a propagate_back_debug_handle function, which propagates debug handles from the edge dialect program back to the exported program while maintaining the correctness of operator tracing. Differential Revision: [D78051614](https://our.internmc.facebook.com/intern/diff/D78051614/) [ghstack-poisoned] --- devtools/inspector/_inspector_utils.py | 75 ++++++++++++ .../inspector/tests/inspector_utils_test.py | 114 +++++++++++++++++- exir/TARGETS | 8 ++ exir/debug_handle_utils.py | 27 +++++ exir/passes/TARGETS | 1 + exir/passes/debug_handle_generator_pass.py | 26 +--- 6 files changed, 230 insertions(+), 21 deletions(-) create mode 100644 exir/debug_handle_utils.py diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 32a46ab0276..7e09bb6cb55 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -35,8 +35,17 @@ from executorch.devtools.etdump.serialize import deserialize_from_etdump_flatcc from executorch.devtools.etrecord import ETRecord +from executorch.exir.debug_handle_utils import ( + DEBUG_HANDLE_KEY, + get_greatest_ancestor_node_identifier, +) + +from executorch.exir.graph_module import bfs_trace_with_node_process + from tabulate import tabulate +from torch.export import ExportedProgram + FORWARD = "forward" EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module" @@ -888,3 +897,69 @@ def compare_intermediate_outputs(a: 
Any, b: Any, comparator) -> List[float]: else: # Raise an error if one is a sequence and the other is not raise ValueError("Both inputs must be sequences or both must be non-sequences.") + + +def propagate_back_debug_handle( + exported_program: ExportedProgram, + exported_program_graph_id: int, + edge_dialect_program: ExportedProgram, +) -> bool: + """ + Propagate debug handle from edge dialect program back to the exported program while maintaining the correctness + of operator tracing. + + e.g. + export program: op1 -> op2 -> op3 + edge dialect program: op1_0 -> op3_0 -> op3_1 + where op1_0 is from op1, op3_0 and op3_1 are from op3, op2 is removed by to_edge pipeline (e.g. RemoveNoopPass). + + Then debug handle of op1 should be the same as op1_0, and debug handle of op3 should be the same as op3_0 and op3_1. + The debug handle of op2 will be a non-existing debug handle in edge dialect program for further skipping. + + Return: True if: + a. every debug handle in the edge dialect program has a corresponding node in the exported program + b. the exported program is the greatest ancestor of the edge dialect program + + Otherwise, return False. + """ + + # 1. set up a mapping from identifier of export program's node to debug handle + # using edge dialect program nodes' debug handles and from_node info + export_graph_node_id_to_debug_handle = { + get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY] + for node in edge_dialect_program.graph.nodes + if node.op not in ("placeholder", "output") + } + + # 2. 
equip debug handle to the exported program's nodes using the mapping + # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle + n_matched_node = 0 + + # debug handle for the node in the exported program but not in the edge dialect program + debug_handle_for_removed_node = ( + max(export_graph_node_id_to_debug_handle.values()) + 1 + ) + + def _find_n_match_node(node: torch.fx.Node) -> None: + nonlocal n_matched_node + if node.name in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + if node_id in export_graph_node_id_to_debug_handle: + n_matched_node += 1 + + def _equip_debug_handle(node: torch.fx.Node) -> None: + if node.name in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + if node_id in export_graph_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id] + else: + node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node + + bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) + if n_matched_node != len(export_graph_node_id_to_debug_handle): + return False + + bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) + return True diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 47113910e98..3350f4cce77 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -10,8 +10,9 @@ import unittest from typing import Dict, Tuple -import torch +import executorch.exir.tests.models as models +import torch from executorch.devtools import generate_etrecord, parse_etrecord from executorch.devtools.debug_format.base_schema import ( @@ -41,9 +42,13 @@ map_runtime_aot_intermediate_outputs, merge_runtime_overlapping_debug_handles, NodeFilter, + propagate_back_debug_handle, TimeScale, ) from 
executorch.devtools.inspector.numerical_comparator import L1Comparator +from executorch.exir import EdgeCompileConfig, to_edge +from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY +from torch.export import export class TestInspectorUtils(unittest.TestCase): @@ -583,6 +588,113 @@ def test_compare_intermediate_outputs_sequence_and_non_sequence(self): with self.assertRaises(ValueError): compare_intermediate_outputs(a, b, L1Comparator()) + def test_equip_debug_handle_to_export_program_success(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + # Export the model + exported_program = export(model, inputs) + export_graph_id = id(exported_program.graph) + + # Convert to edge dialect + edge_dialect_program = to_edge(exported_program).exported_program() + + # Call propagate_back_debug_handle + result = propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + + self.assertTrue(result) + + # Check that debug handles are properly equipped in the exported program + exported_program_debug_handles = [] + for node in exported_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + edge_dialect_program_debug_handles = [] + for node in edge_dialect_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + # The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem) + # So they should have the same debug handle + self.assertEqual( + 
exported_program_debug_handles[0], edge_dialect_program_debug_handles[0] + ) + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] + ) + + def test_equip_debug_handle_to_export_program_failure(self): + """Test that propagate_back_debug_handle returns False when there's a mismatch.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + exported_program = export(model, inputs) + edge_dialect_program = to_edge(exported_program).exported_program() + + # Create a different exported program (reexport) to cause mismatch + reexported_program = export(model, inputs) + reexport_graph_id = id(reexported_program.graph) + + # Call propagate_back_debug_handle with mismatched programs + # This should return False because the reexported program has different node identifiers + result = propagate_back_debug_handle( + reexported_program, reexport_graph_id, edge_dialect_program + ) + + # Check that it returns False due to mismatch + self.assertFalse(result) + + def test_equip_debug_handle_to_export_program_op_to_be_removed_in_to_edge(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles when an op is removed in to_edge""" + + class M(torch.nn.Module): + """ + Simple model with ops that will be removed in to_edge + """ + + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + 1 + x = x.to(x.dtype) + x = x + 1 + return x + + inputs = (torch.rand(5, 5),) + exported_program = torch.export.export(M(), inputs) + export_graph_id = id(exported_program.graph) + edge_dialect_program = to_edge(exported_program).exported_program() + + self.assertTrue( + propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + ) + + # only two add ops in the exported program will keep in edge dialect program, so the debug handles for removed op will be three + 
debug_handle_for_removed_node = 3 + + for node in exported_program.graph.nodes: + if node.name == "add": + self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 1) + elif node.name == "add_1": + self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2) + elif node.op not in ("placeholder", "output"): + self.assertEqual( + node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node + ) + def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]] diff --git a/exir/TARGETS b/exir/TARGETS index 7916cec29fb..cda57de7f80 100644 --- a/exir/TARGETS +++ b/exir/TARGETS @@ -277,3 +277,11 @@ python_library( "fbsource//third-party/pypi/typing-extensions:typing-extensions", ], ) + +python_library( + name = "debug_handle_utils", + srcs = ["debug_handle_utils.py"], + deps = [ + "//caffe2:torch", + ], +) diff --git a/exir/debug_handle_utils.py b/exir/debug_handle_utils.py new file mode 100644 index 00000000000..d1a70fcd213 --- /dev/null +++ b/exir/debug_handle_utils.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.fx import Node + +FROM_NODE_KEY = "from_node" +DEBUG_HANDLE_KEY = "debug_handle" + + +def get_greatest_ancestor_node_identifier(node: Node) -> str: + """Get the identifier of the greatest ancestor node of the given node. + + The identifier is the concatenation of the node name and graph id of the + greatest ancestor node, where the graph id is the unique id for every graph + module in the export flow and node name is unique within the same graph module. 
+ """ + + node_source = node.meta[FROM_NODE_KEY] + node_source = node_source[-1] + + while len(node_source.from_node) > 0: + node_source = node_source.from_node[-1] + + return f"{node_source.name}.{str(node_source.graph_id)}" diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index 8699fe2fd02..0a1f5117f20 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -342,6 +342,7 @@ python_library( ], deps = [ "//caffe2:torch", + "//executorch/exir:debug_handle_utils", "//executorch/exir:graph_module", "//executorch/exir:pass_base", ], diff --git a/exir/passes/debug_handle_generator_pass.py b/exir/passes/debug_handle_generator_pass.py index 425558664b3..fe705273a51 100644 --- a/exir/passes/debug_handle_generator_pass.py +++ b/exir/passes/debug_handle_generator_pass.py @@ -6,6 +6,11 @@ from typing import Dict +from executorch.exir.debug_handle_utils import ( + DEBUG_HANDLE_KEY, + FROM_NODE_KEY, + get_greatest_ancestor_node_identifier, +) from executorch.exir.graph_module import bfs_trace_with_node_process from executorch.exir.pass_base import ExportPass from torch.export import ExportedProgram @@ -21,27 +26,8 @@ def call(self, graph_module: GraphModule) -> PassResult: greatest ancestor node in the export flow. """ - FROM_NODE_KEY = "from_node" - DEBUG_HANDLE_KEY = "debug_handle" - source_node_id_to_debug_handle: Dict[str, int] = {} - def _get_greatest_ancestor_node_identifier(node: Node) -> str: - """Get the identifier of the greatest ancestor node of the given node. - - The identifier is the concatenation of the node name and graph id of the - greatest ancestor node, where the graph id is the unique id for every graph - module in the export flow and node name is unique within the same graph module. 
- """ - - node_source = node.meta[FROM_NODE_KEY] - node_source = node_source[-1] - - while len(node_source.from_node) > 0: - node_source = node_source.from_node[-1] - - return node_source.name + str(node_source.graph_id) - def _extract_debug_handles_from_node(node: Node) -> None: """ Generate a debug handle based on node's oldest ancestor node's name @@ -56,7 +42,7 @@ def _extract_debug_handles_from_node(node: Node) -> None: FROM_NODE_KEY in node.meta ), f"Node {node} does not have meta key {FROM_NODE_KEY}" - greatest_ancestor_node_id = _get_greatest_ancestor_node_identifier(node) + greatest_ancestor_node_id = get_greatest_ancestor_node_identifier(node) debug_handle = ( len(source_node_id_to_debug_handle) + 1