diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index c5e4bbc9a06..17a7451aadf 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -62,6 +62,7 @@ map_runtime_aot_intermediate_outputs, merge_runtime_overlapping_debug_handles, ProgramOutput, + propagate_back_debug_handle, RESERVED_FRAMEWORK_EVENT_NAMES, TimeScale, verify_debug_data_equivalence, @@ -1166,7 +1167,18 @@ def _get_aot_intermediate_outputs_and_op_names( """ if self._etrecord._representative_inputs is None: return {}, {} - export_program = self._etrecord.edge_dialect_program + + export_program = None + + # Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is the greatest ancestor of the edge_dialect_program + if self._etrecord.exported_program and propagate_back_debug_handle( + self._etrecord.exported_program, + self._etrecord.export_graph_id, + self._etrecord.edge_dialect_program, + ): + export_program = self._etrecord.exported_program + else: + export_program = self._etrecord.edge_dialect_program graph_module = export_program.module() aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping( graph_module diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 11b7b4f70e3..2bda03b4873 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -11,7 +11,7 @@ from collections.abc import Sequence from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union +from typing import Any, Dict, IO, List, Mapping, Optional, Set, Tuple, TypeAlias, Union import executorch.devtools.etdump.schema_flatcc as flatcc @@ -37,7 +37,7 @@ from executorch.exir.debug_handle_utils import ( DEBUG_HANDLE_KEY, - get_greatest_ancestor_node_identifier, + FROM_NODE_KEY, UNSET_DEBUG_HANDLE, ) @@ -46,6 +46,7 @@ from tabulate import tabulate from torch.export import ExportedProgram +from torch.fx import Node FORWARD = "forward" EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module" @@ -936,6 +937,133 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]: ) +def get_ancestor_node_identifiers(node: Node) -> List[str]: + """Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in. + + The identifier is the concatenation of the node name and graph id of the + greatest ancestor node, where the graph id is the unique id for every graph + module in the export flow and node name is unique within the same graph module. + + Returns: the identifiers of all its ancestor nodes + """ + + node_source = node.meta[FROM_NODE_KEY] + node_source = node_source[-1] + ancestor_node_ids: List[str] = [f"{node_source.name}.{str(node_source.graph_id)}"] + + while len(node_source.from_node) > 0: + node_source = node_source.from_node[-1] + ancestor_node_ids.append(f"{node_source.name}.{str(node_source.graph_id)}") + + return ancestor_node_ids + + +def get_parent_node_identifier(node: Node) -> Optional[str]: + """Get the identifier of the parent node of the given node, with the graph id the parent node lives in. + + The identifier is the concatenation of the node name and graph id of the + greatest parent node, where the graph id is the unique id for every graph + module in the export flow and node name is unique within the same graph module. + + Returns: the identifier of the parent node, or None if can not find the parent + """ + + if FROM_NODE_KEY not in node.meta: + return None + + node_source = node.meta[FROM_NODE_KEY][-1] + return f"{node_source.name}.{str(node_source.graph_id)}" + + +def _extract_ancestor_debug_handles( + edge_dialect_program: ExportedProgram, +) -> Dict[str, int]: + """Extract mapping from ancestor node identifiers to debug handles.""" + ancestors_node_id_to_debug_handle: Dict[str, int] = {} + + def _extract_node_id_to_debug_handle(node: Node) -> None: + if node.op in ("placeholder", "output"): + return + for ancestor_node_id in get_ancestor_node_identifiers(node): + if ancestor_node_id not in ancestors_node_id_to_debug_handle: + ancestors_node_id_to_debug_handle[ancestor_node_id] = node.meta[ + DEBUG_HANDLE_KEY + ] + else: + assert ( + ancestors_node_id_to_debug_handle[ancestor_node_id] + == node.meta[DEBUG_HANDLE_KEY] + ) + + bfs_trace_with_node_process( + edge_dialect_program.graph_module, _extract_node_id_to_debug_handle + ) + return ancestors_node_id_to_debug_handle + + +def _find_matched_debug_handles( + exported_program: ExportedProgram, + exported_program_graph_id: int, + ancestors_node_id_to_debug_handle: Dict[str, int], +) -> Set[int]: + """Find debug handles that have corresponding nodes in the exported program.""" + matched_debug_handles: Set[int] = set() + + def _find_n_match_node(node: Node) -> None: + if node.op in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + parent_node_id = get_parent_node_identifier(node) + if node_id in ancestors_node_id_to_debug_handle: + matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id]) + elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: + matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id]) + + bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) + return matched_debug_handles + + +def _verify_graph_match( + edge_dialect_program: ExportedProgram, matched_debug_handles: Set[int] +) -> bool: + """Verify if every debug handle in edge dialect program has a corresponding node.""" + graph_matched = True + + def _check_graph_match(node: Node) -> None: + nonlocal graph_matched + if node.op in ("output", "placeholder"): + return + if node.meta[DEBUG_HANDLE_KEY] not in matched_debug_handles: + graph_matched = False + + bfs_trace_with_node_process(edge_dialect_program.graph_module, _check_graph_match) + return graph_matched + + +def _apply_debug_handles( + exported_program: ExportedProgram, + exported_program_graph_id: int, + ancestors_node_id_to_debug_handle: Dict[str, int], +) -> None: + """Apply debug handles to the exported program nodes.""" + + def _equip_debug_handle(node: Node) -> None: + if node.op in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + parent_node_id = get_parent_node_identifier(node) + if node_id in ancestors_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id] + elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[ + parent_node_id + ] + else: + node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE + + bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) + + def propagate_back_debug_handle( exported_program: ExportedProgram, exported_program_graph_id: int, @@ -953,47 +1081,24 @@ def propagate_back_debug_handle( Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping. - Return: True if: - a. every debug handle in the edge dialect program has a corresponding node in the exported program - b. the exported program is the greatest ancestor of the edge dialect program - - Otherwise, return False. + Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False. """ + # 1. Extract mapping from ancestor node identifiers to debug handles + ancestors_node_id_to_debug_handle = _extract_ancestor_debug_handles( + edge_dialect_program + ) - # 1. set up a mapping from debug handle to identifier of export program's node - # using edge dialect program nodes' debug handles and from_node info - export_graph_node_id_to_debug_handle = { - get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY] - for node in edge_dialect_program.graph.nodes - if node.op not in ("placeholder", "output") - } - - # 2. equip debug handle to the exported program's nodes using the mapping - # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle - n_matched_node = 0 - - def _find_n_match_node(node: torch.fx.Node) -> None: - nonlocal n_matched_node - if node.name in ("output", "placeholder"): - return - node_id = f"{node.name}.{exported_program_graph_id}" - if node_id in export_graph_node_id_to_debug_handle: - n_matched_node += 1 - - def _equip_debug_handle(node: torch.fx.Node) -> None: - if node.name in ("output", "placeholder"): - return - node_id = f"{node.name}.{exported_program_graph_id}" - if node_id in export_graph_node_id_to_debug_handle: - node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id] - else: - node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE - - bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) + # 2. Find debug handles that have corresponding nodes in the exported program + matched_debug_handles = _find_matched_debug_handles( + exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle + ) - # if any node in the edge dialect program has no corresponding node in the exported program, match failed - if n_matched_node != len(export_graph_node_id_to_debug_handle): + # 3. Verify if every debug handle in edge dialect program has a corresponding node + if not _verify_graph_match(edge_dialect_program, matched_debug_handles): return False - bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) + # 4. Apply debug handles to the exported program + _apply_debug_handles( + exported_program, exported_program_graph_id, ancestors_node_id_to_debug_handle + ) return True diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index c36311afeab..37dc7921923 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -25,7 +25,6 @@ from executorch.devtools import generate_etrecord, parse_etrecord from executorch.devtools.debug_format.et_schema import OperatorNode from executorch.devtools.etdump.schema_flatcc import ProfileEvent -from executorch.devtools.etrecord._etrecord import ETRecord from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord from executorch.devtools.inspector import ( @@ -480,7 +479,7 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self): events=events, ) - def test_etrecord_populates_correct_aot_intermediate_outputs(self): + def test_etrecord_populates_correct_edge_dialect_aot_intermediate_outputs(self): with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: etrecord_path = tmp_file.name mod = model_registry["ConvLinearModel"]() @@ -513,15 +512,11 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self): etdump_path=ETDUMP_PATH, etrecord=etrecord_path, ) - etrecord = ETRecord( - edge_dialect_program=inspector_instance._etrecord.edge_dialect_program, - graph_map=inspector_instance._etrecord.graph_map, - _debug_handle_map=inspector_instance._etrecord._debug_handle_map, - _delegate_map=inspector_instance._etrecord._delegate_map, - _reference_outputs=inspector_instance._etrecord._reference_outputs, - _representative_inputs=aten_model.example_inputs[0], + + inspector_instance._etrecord._representative_inputs = ( + aten_model.example_inputs[0] ) - inspector_instance._etrecord = etrecord + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( inspector_instance._get_aot_intermediate_outputs_and_op_names() ) @@ -534,7 +529,61 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self): self.assertTrue( check_if_debug_handle_to_op_names_match( - "ConvLinearModel", aot_debug_handle_to_op_names + aot_debug_handle_to_op_names, + mod.get_edge_dialect_expected_debug_handle_to_op_names(), + ) + ) + + def test_etrecord_populates_correct_export_program_aot_intermediate_outputs(self): + with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: + etrecord_path = tmp_file.name + mod = model_registry["ConvLinearModel"]() + input_tensor = mod.get_input() + aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True) + edge_program_manager: EdgeProgramManager = to_edge(aten_model) + edge_program_manager_copy = copy.deepcopy(edge_program_manager) + et_program_manager: ExecutorchProgramManager = ( + edge_program_manager.to_executorch() + ) + # Generate ETRecord with the exported program + generate_etrecord( + etrecord_path, + edge_program_manager_copy, + et_program_manager, + exported_program=aten_model, + ) + with patch.object( + Inspector, "_consume_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + # Call the constructor of Inspector + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=etrecord_path, + ) + + inspector_instance._etrecord._representative_inputs = ( + aten_model.example_inputs[0] + ) + + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( + inspector_instance._get_aot_intermediate_outputs_and_op_names() + ) + self.assertTrue( + check_if_intermediate_outputs_match( + aot_intermediate_outputs, + mod.get_exported_program_expected_intermediate_outputs(), + ) + ) + self.assertTrue( + check_if_debug_handle_to_op_names_match( + aot_debug_handle_to_op_names, + mod.get_exported_program_expected_debug_handle_to_op_names(), ) ) diff --git a/devtools/inspector/tests/inspector_test_utils.py b/devtools/inspector/tests/inspector_test_utils.py index da426377564..69c787608b1 100644 --- a/devtools/inspector/tests/inspector_test_utils.py +++ b/devtools/inspector/tests/inspector_test_utils.py @@ -79,7 +79,7 @@ def get_edge_dialect_expected_intermediate_outputs(): } @staticmethod - def get_expected_debug_handle_to_op_names(): + def get_edge_dialect_expected_debug_handle_to_op_names(): """ Returns the expected debug handle and op names mapping for this model for the given input. """ @@ -100,7 +100,7 @@ def get_expected_debug_handle_to_op_names(): @staticmethod def get_exported_program_expected_intermediate_outputs(): """ - Returns the expected outputs of the debug handles and intermediate output mapping for edge dialect graph of this model for the given input. + Returns the expected outputs of the debug handles and intermediate output mapping for export graph of this model for the given input. """ return { (UNSET_DEBUG_HANDLE,): torch.tensor([[5.4000, 13.5200]]), @@ -117,6 +117,26 @@ def get_exported_program_expected_intermediate_outputs(): (11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], } + @staticmethod + def get_exported_program_expected_debug_handle_to_op_names(): + """ + Returns the expected debug handle and op name mapping for this model for the given input. + """ + return { + (UNSET_DEBUG_HANDLE,): ["_assert_tensor_metadata_default", "to"], + (1,): ["conv2d"], + (2,): ["view"], + (3,): ["linear"], + (4,): ["add"], + (5,): ["sub"], + (6,): ["mul"], + (7,): ["add_1"], + (8,): ["div"], + (9,): ["relu"], + (10,): ["sigmoid"], + (11,): ["split"], + } + # Global model registry model_registry = { @@ -153,15 +173,13 @@ def check_if_intermediate_outputs_match( return True -def check_if_debug_handle_to_op_names_match(model_name, actual_debug_handle_to_op_name): +def check_if_debug_handle_to_op_names_match( + actual_debug_handle_to_op_name, expected_debug_handle_to_op_name +): """ Checks if the actual op names match the expected op names for the specified model. Returns True if all match, otherwise returns False. """ - model_instance = model_registry[model_name] - expected_debug_handle_to_op_name = ( - model_instance.get_expected_debug_handle_to_op_names() - ) if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name): return False for debug_handle, expected_op_name in expected_debug_handle_to_op_name.items(): diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 69aa6f65dec..ea8c0e653af 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -654,6 +654,95 @@ def test_equip_debug_handle_to_export_program_success(self): exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] ) + def test_equip_debug_handle_to_strict_export_program_success(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + # Export the model + exported_program = export(model, inputs, strict=True) + export_graph_id = id(exported_program.graph) + + # Convert to edge dialect + edge_dialect_program = to_edge(exported_program).exported_program() + + # Call propagate_back_debug_handle + result = propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + + self.assertTrue(result) + + # Check that debug handles are properly equipped in the exported program + exported_program_debug_handles = [] + for node in exported_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + edge_dialect_program_debug_handles = [] + for node in edge_dialect_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + # The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem) + # So they should have the same debug handle + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[0] + ) + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] + ) + + def test_equip_debug_handle_to_reexport_program_success(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + # Export the model + init_export_program = export(model, inputs) + exported_program = export(init_export_program.module(), inputs) + export_graph_id = id(exported_program.graph) + + # Convert to edge dialect + edge_dialect_program = to_edge(exported_program).exported_program() + + # Call propagate_back_debug_handle + result = propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + + self.assertTrue(result) + + # Check that debug handles are properly equipped in the exported program + exported_program_debug_handles = [] + for node in exported_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + edge_dialect_program_debug_handles = [] + for node in edge_dialect_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + # The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem) + # So they should have the same debug handle + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[0] + ) + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] + ) + def test_equip_debug_handle_to_export_program_failure(self): """Test that propagate_back_debug_handle returns False when there's a mismatch.""" # Create a test model