From 80c8ed84557099d041634949e4c1709b10623de3 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 28 Jul 2025 10:11:52 -0700 Subject: [PATCH 1/2] support back propagate debug handle to arbitrary ancestor export graph Pull Request resolved: https://github.com/pytorch/executorch/pull/12580 Currently, the propagate_back_debug_handle function can only propagate debug handles back to the greatest ancestor export graph. This diff updates the algorithm to support every possible ancestor export graph in the flow. ghstack-source-id: 299039097 Differential Revision: [D78464992](https://our.internmc.facebook.com/intern/diff/D78464992/) --- devtools/inspector/_inspector_utils.py | 185 ++++++++++++++---- .../inspector/tests/inspector_utils_test.py | 89 +++++++++ 2 files changed, 234 insertions(+), 40 deletions(-) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 11b7b4f70e3..2bda03b4873 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -11,7 +11,7 @@ from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union +from typing import Any, Dict, IO, List, Mapping, Optional, Set, Tuple, TypeAlias, Union import executorch.devtools.etdump.schema_flatcc as flatcc @@ -37,7 +37,7 @@ from executorch.exir.debug_handle_utils import ( DEBUG_HANDLE_KEY, - get_greatest_ancestor_node_identifier, + FROM_NODE_KEY, UNSET_DEBUG_HANDLE, ) @@ -46,6 +46,7 @@ from tabulate import tabulate from torch.export import ExportedProgram +from torch.fx import Node FORWARD = "forward" EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module" @@ -936,6 +937,133 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]: ) +def get_ancestor_node_identifiers(node: Node) -> List[str]: + """Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in. 
+ + The identifier is the concatenation of the node name and graph id of the + greatest ancestor node, where the graph id is the unique id for every graph + module in the export flow and node name is unique within the same graph module. + + Returns: the identifiers of all its ancestor nodes + """ + + node_source = node.meta[FROM_NODE_KEY] + node_source = node_source[-1] + ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"] + + while len(node_source.from_node) > 0: + node_source = node_source.from_node[-1] + ancestor_node_ids.append(f"{node_source.name}.{str(node_source.graph_id)}") + + return ancestor_node_ids + + +def get_parent_node_identifier(node: Node) -> Optional[str]: + """Get the identifier of the parent node of the given node, with the graph id the parent node lives in. + + The identifier is the concatenation of the node name and graph id of the + greatest parent node, where the graph id is the unique id for every graph + module in the export flow and node name is unique within the same graph module. 
+ + Returns: the identifier of the parent node, or None if can not find the parent + """ + + if FROM_NODE_KEY not in node.meta: + return None + + node_source = node.meta[FROM_NODE_KEY][-1] + return f"{node_source.name}.{str(node_source.graph_id)}" + + +def _extract_ancestor_debug_handles( + edge_dialect_program: ExportedProgram, +) -> Dict[str, int]: + """Extract mapping from ancestor node identifiers to debug handles.""" + ancestors_node_id_to_debug_handle: Dict[str, int] = {} + + def _extract_node_id_to_debug_handle(node: Node) -> None: + if node.op in ("placeholder", "output"): + return + for ancestor_node_id in get_ancestor_node_identifiers(node): + if ancestor_node_id not in ancestors_node_id_to_debug_handle: + ancestors_node_id_to_debug_handle[ancestor_node_id] = node.meta[ + DEBUG_HANDLE_KEY + ] + else: + assert ( + ancestors_node_id_to_debug_handle[ancestor_node_id] + == node.meta[DEBUG_HANDLE_KEY] + ) + + bfs_trace_with_node_process( + edge_dialect_program.graph_module, _extract_node_id_to_debug_handle + ) + return ancestors_node_id_to_debug_handle + + +def _find_matched_debug_handles( + exported_program: ExportedProgram, + exported_program_graph_id: int, + ancestors_node_id_to_debug_handle: Dict[str, int], +) -> Set[int]: + """Find debug handles that have corresponding nodes in the exported program.""" + matched_debug_handles: Set[int] = set() + + def _find_n_match_node(node: Node) -> None: + if node.op in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + parent_node_id = get_parent_node_identifier(node) + if node_id in ancestors_node_id_to_debug_handle: + matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id]) + elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: + matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id]) + + bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) + return matched_debug_handles + + +def 
_verify_graph_match( + edge_dialect_program: ExportedProgram, matched_debug_handles: Set[int] +) -> bool: + """Verify if every debug handle in edge dialect program has a corresponding node.""" + graph_matched = True + + def _check_graph_match(node: Node) -> None: + nonlocal graph_matched + if node.op in ("output", "placeholder"): + return + if node.meta[DEBUG_HANDLE_KEY] not in matched_debug_handles: + graph_matched = False + + bfs_trace_with_node_process(edge_dialect_program.graph_module, _check_graph_match) + return graph_matched + + +def _apply_debug_handles( + exported_program: ExportedProgram, + exported_program_graph_id: int, + ancestors_node_id_to_debug_handle: Dict[str, int], +) -> None: + """Apply debug handles to the exported program nodes.""" + + def _equip_debug_handle(node: Node) -> None: + if node.op in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + parent_node_id = get_parent_node_identifier(node) + if node_id in ancestors_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id] + elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[ + parent_node_id + ] + else: + node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE + + bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) + + def propagate_back_debug_handle( exported_program: ExportedProgram, exported_program_graph_id: int, @@ -953,47 +1081,24 @@ def propagate_back_debug_handle( Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping. - Return: True if: - a. every debug handle in the edge dialect program has a corresponding node in the exported program - b. the exported program is the greatest ancestor of the edge dialect program - - Otherwise, return False. 
+ Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False. """ + # 1. Extract mapping from ancestor node identifiers to debug handles + ancestors_node_id_to_debug_handle = _extract_ancestor_debug_handles( + edge_dialect_program + ) - # 1. set up a mapping from debug handle to identifier of export program's node - # using edge dialect program nodes' debug handles and from_node info - export_graph_node_id_to_debug_handle = { - get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY] - for node in edge_dialect_program.graph.nodes - if node.op not in ("placeholder", "output") - } - - # 2. equip debug handle to the exported program's nodes using the mapping - # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle - n_matched_node = 0 - - def _find_n_match_node(node: torch.fx.Node) -> None: - nonlocal n_matched_node - if node.name in ("output", "placeholder"): - return - node_id = f"{node.name}.{exported_program_graph_id}" - if node_id in export_graph_node_id_to_debug_handle: - n_matched_node += 1 - - def _equip_debug_handle(node: torch.fx.Node) -> None: - if node.name in ("output", "placeholder"): - return - node_id = f"{node.name}.{exported_program_graph_id}" - if node_id in export_graph_node_id_to_debug_handle: - node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id] - else: - node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE - - bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) + # 2. Find debug handles that have corresponding nodes in the exported program + matched_debug_handles = _find_matched_debug_handles( + exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle + ) - # if any node in the edge dialect program has no corresponding node in the exported program, match failed - if n_matched_node != len(export_graph_node_id_to_debug_handle): + # 3. 
Verify if every debug handle in edge dialect program has a corresponding node + if not _verify_graph_match(edge_dialect_program, matched_debug_handles): return False - bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) + # 4. Apply debug handles to the exported program + _apply_debug_handles( + exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle + ) return True diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 69aa6f65dec..ea8c0e653af 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -654,6 +654,95 @@ def test_equip_debug_handle_to_export_program_success(self): exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] ) + def test_equip_debug_handle_to_strict_export_program_success(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + # Export the model + exported_program = export(model, inputs, strict=True) + export_graph_id = id(exported_program.graph) + + # Convert to edge dialect + edge_dialect_program = to_edge(exported_program).exported_program() + + # Call propagate_back_debug_handle + result = propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + + self.assertTrue(result) + + # Check that debug handles are properly equipped in the exported program + exported_program_debug_handles = [] + for node in exported_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + edge_dialect_program_debug_handles = [] + for node in edge_dialect_program.graph.nodes: + if node.op not in ("placeholder", 
"output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + # The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem) + # So they should have the same debug handle + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[0] + ) + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] + ) + + def test_equip_debug_handle_to_reexport_program_success(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + # Export the model + init_export_program = export(model, inputs) + exported_program = export(init_export_program.module(), inputs) + export_graph_id = id(exported_program.graph) + + # Convert to edge dialect + edge_dialect_program = to_edge(exported_program).exported_program() + + # Call propagate_back_debug_handle + result = propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + + self.assertTrue(result) + + # Check that debug handles are properly equipped in the exported program + exported_program_debug_handles = [] + for node in exported_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + edge_dialect_program_debug_handles = [] + for node in edge_dialect_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + # The 0th operator in the 
exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem) + # So they should have the same debug handle + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[0] + ) + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] + ) + def test_equip_debug_handle_to_export_program_failure(self): """Test that propagate_back_debug_handle returns False when there's a mismatch.""" # Create a test model From d05ee885e1acdd7aa18894f8c7d2e0577dff3c42 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 28 Jul 2025 10:17:18 -0700 Subject: [PATCH 2/2] Support export program in intermediate numeric discrepancy detector Pull Request resolved: https://github.com/pytorch/executorch/pull/12581 This diff enables the intermediate numeric discrepancy detector to use the exported program as the label. More specifically, if the user creates an ETRecord with an exported program, and that exported program is one of the exported programs in the export flow, then the numeric discrepancy detector will use it as the label. Otherwise, it will continue to use the edge dialect graph as the label. 
ghstack-source-id: 299040326 @exported-using-ghexport Differential Revision: [D78298935](https://our.internmc.facebook.com/intern/diff/D78298935/) --- devtools/inspector/_inspector.py | 14 +++- devtools/inspector/tests/inspector_test.py | 71 ++++++++++++++++--- .../inspector/tests/inspector_test_utils.py | 32 +++++++-- 3 files changed, 98 insertions(+), 19 deletions(-) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index c5e4bbc9a06..17a7451aadf 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -62,6 +62,7 @@ map_runtime_aot_intermediate_outputs, merge_runtime_overlapping_debug_handles, ProgramOutput, + propagate_back_debug_handle, RESERVED_FRAMEWORK_EVENT_NAMES, TimeScale, verify_debug_data_equivalence, @@ -1166,7 +1167,18 @@ def _get_aot_intermediate_outputs_and_op_names( """ if self._etrecord._representative_inputs is None: return {}, {} - export_program = self._etrecord.edge_dialect_program + + export_program = None + + # Use the exported program to extract intermediate outputs if and only if exported_program has been provided and propagate_back_debug_handle succeeds, i.e. 
every debug handle in the edge_dialect_program maps back to a node in it (the exported program may be any ancestor export graph, not only the greatest one); otherwise fall back to the edge_dialect_program + if self._etrecord.exported_program and propagate_back_debug_handle( + self._etrecord.exported_program, + self._etrecord.export_graph_id, + self._etrecord.edge_dialect_program, + ): + export_program = self._etrecord.exported_program + else: + export_program = self._etrecord.edge_dialect_program graph_module = export_program.module() aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping( graph_module diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index c36311afeab..37dc7921923 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -25,7 +25,6 @@ from executorch.devtools import generate_etrecord, parse_etrecord from executorch.devtools.debug_format.et_schema import OperatorNode from 
executorch.devtools.etdump.schema_flatcc import ProfileEvent -from executorch.devtools.etrecord._etrecord import ETRecord from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord from executorch.devtools.inspector import ( @@ -480,7 +479,7 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self): events=events, ) - def test_etrecord_populates_correct_aot_intermediate_outputs(self): + def test_etrecord_populates_correct_edge_dialect_aot_intermediate_outputs(self): with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: etrecord_path = tmp_file.name mod = model_registry["ConvLinearModel"]() @@ -513,15 +512,11 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self): etdump_path=ETDUMP_PATH, etrecord=etrecord_path, ) - etrecord = ETRecord( - edge_dialect_program=inspector_instance._etrecord.edge_dialect_program, - graph_map=inspector_instance._etrecord.graph_map, - _debug_handle_map=inspector_instance._etrecord._debug_handle_map, - _delegate_map=inspector_instance._etrecord._delegate_map, - _reference_outputs=inspector_instance._etrecord._reference_outputs, - _representative_inputs=aten_model.example_inputs[0], + + inspector_instance._etrecord._representative_inputs = ( + aten_model.example_inputs[0] ) - inspector_instance._etrecord = etrecord + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( inspector_instance._get_aot_intermediate_outputs_and_op_names() ) @@ -534,7 +529,61 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self): self.assertTrue( check_if_debug_handle_to_op_names_match( - "ConvLinearModel", aot_debug_handle_to_op_names + aot_debug_handle_to_op_names, + mod.get_edge_dialect_expected_debug_handle_to_op_names(), + ) + ) + + def test_etrecord_populates_correct_export_program_aot_intermediate_outputs(self): + with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: + etrecord_path = tmp_file.name + mod = model_registry["ConvLinearModel"]() + input_tensor = 
mod.get_input() + aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True) + edge_program_manager: EdgeProgramManager = to_edge(aten_model) + edge_program_manager_copy = copy.deepcopy(edge_program_manager) + et_program_manager: ExecutorchProgramManager = ( + edge_program_manager.to_executorch() + ) + # Generate ETRecord with the exported program + generate_etrecord( + etrecord_path, + edge_program_manager_copy, + et_program_manager, + exported_program=aten_model, + ) + with patch.object( + Inspector, "_consume_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + # Call the constructor of Inspector + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=etrecord_path, + ) + + inspector_instance._etrecord._representative_inputs = ( + aten_model.example_inputs[0] + ) + + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( + inspector_instance._get_aot_intermediate_outputs_and_op_names() + ) + self.assertTrue( + check_if_intermediate_outputs_match( + aot_intermediate_outputs, + mod.get_exported_program_expected_intermediate_outputs(), + ) + ) + self.assertTrue( + check_if_debug_handle_to_op_names_match( + aot_debug_handle_to_op_names, + mod.get_exported_program_expected_debug_handle_to_op_names(), ) ) diff --git a/devtools/inspector/tests/inspector_test_utils.py b/devtools/inspector/tests/inspector_test_utils.py index da426377564..69c787608b1 100644 --- a/devtools/inspector/tests/inspector_test_utils.py +++ b/devtools/inspector/tests/inspector_test_utils.py @@ -79,7 +79,7 @@ def get_edge_dialect_expected_intermediate_outputs(): } @staticmethod - def get_expected_debug_handle_to_op_names(): + def get_edge_dialect_expected_debug_handle_to_op_names(): """ Returns the expected debug handle and op names mapping for this model for the given input. 
""" @@ -100,7 +100,7 @@ def get_expected_debug_handle_to_op_names(): @staticmethod def get_exported_program_expected_intermediate_outputs(): """ - Returns the expected outputs of the debug handles and intermediate output mapping for edge dialect graph of this model for the given input. + Returns the expected outputs of the debug handles and intermediate output mapping for export graph of this model for the given input. """ return { (UNSET_DEBUG_HANDLE,): torch.tensor([[5.4000, 13.5200]]), @@ -117,6 +117,26 @@ def get_exported_program_expected_intermediate_outputs(): (11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], } + @staticmethod + def get_exported_program_expected_debug_handle_to_op_names(): + """ + Returns the expected debug handle and op name mapping for this model for the given input. + """ + return { + (UNSET_DEBUG_HANDLE,): ["_assert_tensor_metadata_default", "to"], + (1,): ["conv2d"], + (2,): ["view"], + (3,): ["linear"], + (4,): ["add"], + (5,): ["sub"], + (6,): ["mul"], + (7,): ["add_1"], + (8,): ["div"], + (9,): ["relu"], + (10,): ["sigmoid"], + (11,): ["split"], + } + # Global model registry model_registry = { @@ -153,15 +173,13 @@ def check_if_intermediate_outputs_match( return True -def check_if_debug_handle_to_op_names_match(model_name, actual_debug_handle_to_op_name): +def check_if_debug_handle_to_op_names_match( + actual_debug_handle_to_op_name, expected_debug_handle_to_op_name +): """ Checks if the actual op names match the expected op names for the specified model. Returns True if all match, otherwise returns False. """ - model_instance = model_registry[model_name] - expected_debug_handle_to_op_name = ( - model_instance.get_expected_debug_handle_to_op_names() - ) if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name): return False for debug_handle, expected_op_name in expected_debug_handle_to_op_name.items():