From 0a0f9860df23bba36523ef96d53f74d52a9bb4f0 Mon Sep 17 00:00:00 2001
From: Juntian Liu
Date: Wed, 25 Jun 2025 14:42:17 -0700
Subject: [PATCH] Add function to map AOT debug_handles to op names (#11930)

Summary:
This PR adds a function to map AOT debug handles to operator names in the Export graph. It will be used later to enhance how numerical discrepancy results are shown, making them easier for users to understand.

Reviewed By: Gasoonjia

Differential Revision: D77244175
---
 devtools/inspector/TARGETS                    |   1 +
 devtools/inspector/_inspector.py              |   5 +
 devtools/inspector/_inspector_utils.py        |  50 ++++++++
 .../_intermediate_output_capturer.py          |  25 +---
 .../inspector/tests/inspector_utils_test.py   | 108 ++++++++++++++++++
 5 files changed, 166 insertions(+), 23 deletions(-)

diff --git a/devtools/inspector/TARGETS b/devtools/inspector/TARGETS
index 0712bdf1f9a..d32698f784f 100644
--- a/devtools/inspector/TARGETS
+++ b/devtools/inspector/TARGETS
@@ -56,6 +56,7 @@ python_library(
         "_intermediate_output_capturer.py",
     ],
     deps = [
+        "//executorch/devtools/inspector:inspector_utils",
     ],
 )
 
diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py
index dfff3d0818e..7e2d35af6c2 100644
--- a/devtools/inspector/_inspector.py
+++ b/devtools/inspector/_inspector.py
@@ -52,6 +52,7 @@
     FORWARD,
     gen_etdump_object,
     gen_graphs_from_etrecord,
+    get_aot_debug_handle_to_op_name_mapping,
     inflate_runtime_output,
     is_debug_output,
     is_inference_output_equal,
@@ -1084,6 +1085,7 @@ def __init__(
         self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
         self._enable_module_hierarchy = enable_module_hierarchy
         self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None
+        self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None
         self._consume_etrecord()
 
     def _consume_etrecord(self) -> None:
@@ -1150,6 +1152,9 @@ def _consume_etrecord(self) -> None:
             return
         export_program = self._etrecord.edge_dialect_program
         graph_module = export_program.module()
+        self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping(
+            graph_module
+        )
         capturer = IntermediateOutputCapturer(graph_module)
         self._aot_intermediate_outputs = capturer.run_and_capture(
            self._etrecord._representative_inputs
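For illustration, the new Inspector field populated above maps debug handles (normalized to tuples of ints) to operator names. A hypothetical value, with made-up handles and op names not taken from this change:

    {(1,): "aten_add_tensor", (2, 3): "aten_convolution_default"}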
+ """ + + def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None): + self.metadata_key = metadata_key + self.op_type = op_type + self.exclude_ops = exclude_ops + + def matches(self, node: torch.fx.Node) -> bool: + return ( + node.meta.get(self.metadata_key) is not None + and node.op == self.op_type + and all(exclude_name not in node.name for exclude_name in self.exclude_ops) + ) + + def calculate_time_scale_factor( source_time_scale: TimeScale, target_time_scale: TimeScale ) -> float: @@ -734,3 +756,31 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor: if torch.isnan(input_tensor).any(): input_tensor = torch.nan_to_num(input_tensor) return input_tensor + + +def get_aot_debug_handle_to_op_name_mapping( + graph_module: torch.fx.GraphModule, +) -> Dict[Tuple[int, ...], str]: + """ + Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module. + Parameters: + graph_module (torch.fx.GraphModule): The graph module to get the mapping from. + Returns: + Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names. + """ + node_filters = [ + NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"]) + ] + + debug_handle_to_op_name = {} + for node in graph_module.graph.nodes: + if all(filter.matches(node) for filter in node_filters): + debug_handle = node.meta["debug_handle"] + # Convert the debug handle to a tuple to use as a dictionary key + key = ( + (debug_handle,) + if isinstance(debug_handle, int) + else tuple(debug_handle) + ) + debug_handle_to_op_name[key] = node.name + return debug_handle_to_op_name diff --git a/devtools/inspector/_intermediate_output_capturer.py b/devtools/inspector/_intermediate_output_capturer.py index c1f943bd02c..054c97dc245 100644 --- a/devtools/inspector/_intermediate_output_capturer.py +++ b/devtools/inspector/_intermediate_output_capturer.py @@ -7,35 +7,14 @@ # pyre-unsafe -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, Tuple import torch +from executorch.devtools.inspector._inspector_utils import NodeFilter from torch.fx import GraphModule from torch.fx.interpreter import Interpreter -class NodeFilter: - """ - A class used to filter nodes based on extensible criteria. - Attributes: - metadata_key (str): The key to look for in the node's metadata. - op_type (str): The operation code to match. - exclude_ops (List[str]): A list of operations to exclude from the filter. - """ - - def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None): - self.metadata_key = metadata_key - self.op_type = op_type - self.exclude_ops = exclude_ops - - def matches(self, node: torch.fx.Node) -> bool: - return ( - node.meta.get(self.metadata_key) is not None - and node.op == self.op_type - and all(exclude_name not in node.name for exclude_name in self.exclude_ops) - ) - - class IntermediateOutputCapturer(Interpreter): """ A class that captures intermediate outputs from a PyTorch graph module. 
diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py
index 8148d2c36f0..6d12cb13c5f 100644
--- a/devtools/inspector/tests/inspector_utils_test.py
+++ b/devtools/inspector/tests/inspector_utils_test.py
@@ -34,9 +34,11 @@
     EDGE_DIALECT_GRAPH_KEY,
     find_populated_event,
     gen_graphs_from_etrecord,
+    get_aot_debug_handle_to_op_name_mapping,
     is_inference_output_equal,
     map_runtime_aot_intermediate_outputs,
     merge_overlapping_debug_handles,
+    NodeFilter,
     TimeScale,
 )
 
@@ -364,6 +366,112 @@ class X:
         msg = str(cm.exception)
         self.assertIn("Cannot convert value of type", msg)
 
+    def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self):
+        # Create a simple graph module with one node
+        graph_module = torch.fx.GraphModule({}, torch.fx.Graph())
+        node = graph_module.graph.create_node(
+            "call_function", target=torch.mul, args=(), kwargs={}, name="op1"
+        )
+        node.meta["debug_handle"] = 1
+        debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
+        expected_result = {(1,): "op1"}
+        self.assertEqual(debug_handle_to_op_name, expected_result)
+
+    def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self):
+        # Create a simple graph module with two nodes
+        graph_module = torch.fx.GraphModule({}, torch.fx.Graph())
+        node1 = graph_module.graph.create_node(
+            "call_function", target=torch.mul, args=(), kwargs={}, name="op1"
+        )
+        node1.meta["debug_handle"] = (1, 2)
+        node2 = graph_module.graph.create_node(
+            "call_function", target=torch.mul, args=(), kwargs={}, name="op2"
+        )
+        node2.meta["debug_handle"] = 3
+        debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
+        expected_result = {
+            (
+                1,
+                2,
+            ): "op1",
+            (3,): "op2",
+        }
+        self.assertEqual(debug_handle_to_op_name, expected_result)
+
+    def test_get_aot_debug_handle_to_op_name_mapping_no_debug_handles(self):
+        # Create a simple graph module with no nodes
+        graph_module = torch.fx.GraphModule({}, torch.fx.Graph())
+        debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module)
+        expected_result = {}
+        self.assertEqual(debug_handle_to_op_name, expected_result)
+
+    def test_node_filter_match(self):
+        node_filter = NodeFilter(
+            "debug_handle", "call_function", exclude_ops=["getitem"]
+        )
+
+        # Create a mock node that matches the filter criteria
+        mock_node = torch.fx.Node(
+            graph=torch.fx.Graph(),
+            name="mock_node",
+            op="call_function",
+            target=torch.nn.functional.relu,
+            args=(),
+            kwargs={},
+        )
+        mock_node.meta["debug_handle"] = (1, 2)
+        # Test that the filter matches the mock node
+        self.assertTrue(node_filter.matches(mock_node))
+
+    def test_node_filter_key_mismatch(self):
+        node_filter = NodeFilter(
+            "debug_handle", "call_function", exclude_ops=["getitem"]
+        )
+        mock_node_metadata_key_mismatch = torch.fx.Node(
+            graph=torch.fx.Graph(),
+            name="mock_node_metadata_key_mismatch",
+            op="call_function",
+            target=torch.nn.functional.relu,
+            args=(),
+            kwargs={},
+        )
+        # Test that the filter doesn't match the mock node (meta doesn't have debug_handle key)
+        self.assertFalse(node_filter.matches(mock_node_metadata_key_mismatch))
+
+    def test_node_filter_ops_mismatch(self):
+        node_filter = NodeFilter(
+            "debug_handle", "call_function", exclude_ops=["getitem"]
+        )
+
+        mock_node_exclude_ops_mismatch = torch.fx.Node(
+            graph=torch.fx.Graph(),
+            name="getitem",
+            op="call_function",
+            target=torch.nn.functional.relu,
+            args=(),
+            kwargs={},
+        )
+        mock_node_exclude_ops_mismatch.meta["debug_handle"] = (1, 2)
+        # Test that the filter doesn't match the mock node (exclude_ops mismatch)
+        self.assertFalse(node_filter.matches(mock_node_exclude_ops_mismatch))
+
+    def test_node_op_type_mismatch(self):
+        node_filter = NodeFilter(
+            "debug_handle", "call_function", exclude_ops=["getitem"]
+        )
+
+        mock_node_op_type_mismatch = torch.fx.Node(
+            graph=torch.fx.Graph(),
+            name="mock_node_op_type_mismatch",
+            op="get_attr",
+            target="torch.nn.functional.relu",
+            args=(),
+            kwargs={},
+        )
+        mock_node_op_type_mismatch.meta["debug_handle"] = (1, 2)
+        # Test that the filter doesn't match the mock node (op_type mismatch)
+        self.assertFalse(node_filter.matches(mock_node_op_type_mismatch))
+
 
 def gen_mock_operator_graph_with_expected_map() -> (
     Tuple[OperatorGraph, Dict[int, OperatorNode]]
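For completeness, a minimal standalone sketch of the relocated NodeFilter, mirroring test_node_filter_match above (the relu target and handle values are illustrative):

    import torch

    from executorch.devtools.inspector._inspector_utils import NodeFilter

    node_filter = NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
    node = torch.fx.Node(
        graph=torch.fx.Graph(),
        name="mock_node",
        op="call_function",
        target=torch.nn.functional.relu,
        args=(),
        kwargs={},
    )
    node.meta["debug_handle"] = (1, 2)

    # True: the metadata key is present, the op type matches, and the name is not excluded
    print(node_filter.matches(node))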