diff --git a/devtools/inspector/TARGETS b/devtools/inspector/TARGETS index 0712bdf1f9a..d32698f784f 100644 --- a/devtools/inspector/TARGETS +++ b/devtools/inspector/TARGETS @@ -56,6 +56,7 @@ python_library( "_intermediate_output_capturer.py", ], deps = [ + "//executorch/devtools/inspector:inspector_utils", ], ) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index dfff3d0818e..7e2d35af6c2 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -52,6 +52,7 @@ FORWARD, gen_etdump_object, gen_graphs_from_etrecord, + get_aot_debug_handle_to_op_name_mapping, inflate_runtime_output, is_debug_output, is_inference_output_equal, @@ -1084,6 +1085,7 @@ def __init__( self._reference_outputs: Dict[str, List[ProgramOutput]] = {} self._enable_module_hierarchy = enable_module_hierarchy self._aot_intermediate_outputs: Optional[Dict[Tuple[int, ...], Any]] = None + self._aot_debug_handles_to_op_names: Optional[Dict[Tuple[int, ...], str]] = None self._consume_etrecord() def _consume_etrecord(self) -> None: @@ -1150,6 +1152,9 @@ def _consume_etrecord(self) -> None: return export_program = self._etrecord.edge_dialect_program graph_module = export_program.module() + self._aot_debug_handles_to_op_names = get_aot_debug_handle_to_op_name_mapping( + graph_module + ) capturer = IntermediateOutputCapturer(graph_module) self._aot_intermediate_outputs = capturer.run_and_capture( self._etrecord._representative_inputs diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 61e2ea4d031..50b3669309c 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -93,6 +93,28 @@ class NodeData: output: Any +class NodeFilter: + """ + A class used to filter nodes based on extensible criteria. + Attributes: + metadata_key (str): The key to look for in the node's metadata. + op_type (str): The operation code to match. + exclude_ops (List[str]): A list of operations to exclude from the filter. + """ + + def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None): + self.metadata_key = metadata_key + self.op_type = op_type + self.exclude_ops = exclude_ops + + def matches(self, node: torch.fx.Node) -> bool: + return ( + node.meta.get(self.metadata_key) is not None + and node.op == self.op_type + and all(exclude_name not in node.name for exclude_name in self.exclude_ops) + ) + + def calculate_time_scale_factor( source_time_scale: TimeScale, target_time_scale: TimeScale ) -> float: @@ -734,3 +756,31 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor: if torch.isnan(input_tensor).any(): input_tensor = torch.nan_to_num(input_tensor) return input_tensor + + +def get_aot_debug_handle_to_op_name_mapping( + graph_module: torch.fx.GraphModule, +) -> Dict[Tuple[int, ...], str]: + """ + Get a mapping from debug handle to operator name from the ETRecord edge_dialect_program's graph module. + Parameters: + graph_module (torch.fx.GraphModule): The graph module to get the mapping from. + Returns: + Dict[Tuple[int, ...], str]: A dictionary mapping debug handles to operator names. + """ + node_filters = [ + NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"]) + ] + + debug_handle_to_op_name = {} + for node in graph_module.graph.nodes: + if all(filter.matches(node) for filter in node_filters): + debug_handle = node.meta["debug_handle"] + # Convert the debug handle to a tuple to use as a dictionary key + key = ( + (debug_handle,) + if isinstance(debug_handle, int) + else tuple(debug_handle) + ) + debug_handle_to_op_name[key] = node.name + return debug_handle_to_op_name diff --git a/devtools/inspector/_intermediate_output_capturer.py b/devtools/inspector/_intermediate_output_capturer.py index c1f943bd02c..054c97dc245 100644 --- a/devtools/inspector/_intermediate_output_capturer.py +++ b/devtools/inspector/_intermediate_output_capturer.py @@ -7,35 +7,14 @@ # pyre-unsafe -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, Tuple import torch +from executorch.devtools.inspector._inspector_utils import NodeFilter from torch.fx import GraphModule from torch.fx.interpreter import Interpreter -class NodeFilter: - """ - A class used to filter nodes based on extensible criteria. - Attributes: - metadata_key (str): The key to look for in the node's metadata. - op_type (str): The operation code to match. - exclude_ops (List[str]): A list of operations to exclude from the filter. - """ - - def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None): - self.metadata_key = metadata_key - self.op_type = op_type - self.exclude_ops = exclude_ops - - def matches(self, node: torch.fx.Node) -> bool: - return ( - node.meta.get(self.metadata_key) is not None - and node.op == self.op_type - and all(exclude_name not in node.name for exclude_name in self.exclude_ops) - ) - - class IntermediateOutputCapturer(Interpreter): """ A class that captures intermediate outputs from a PyTorch graph module. diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 8148d2c36f0..6d12cb13c5f 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -34,9 +34,11 @@ EDGE_DIALECT_GRAPH_KEY, find_populated_event, gen_graphs_from_etrecord, + get_aot_debug_handle_to_op_name_mapping, is_inference_output_equal, map_runtime_aot_intermediate_outputs, merge_overlapping_debug_handles, + NodeFilter, TimeScale, ) @@ -364,6 +366,112 @@ class X: msg = str(cm.exception) self.assertIn("Cannot convert value of type", msg) + def test_get_aot_debug_handle_to_op_name_mapping_single_debug_handle(self): + # Create a simple graph module with one node + graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) + node = graph_module.graph.create_node( + "call_function", target=torch.mul, args=(), kwargs={}, name="op1" + ) + node.meta["debug_handle"] = 1 + debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) + expected_result = {(1,): "op1"} + self.assertEqual(debug_handle_to_op_name, expected_result) + + def test_get_aot_debug_handle_to_op_name_mapping_multiple_debug_handles(self): + # Create a simple graph module with two nodes + graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) + node1 = graph_module.graph.create_node( + "call_function", target=torch.mul, args=(), kwargs={}, name="op1" + ) + node1.meta["debug_handle"] = (1, 2) + node2 = graph_module.graph.create_node( + "call_function", target=torch.mul, args=(), kwargs={}, name="op2" + ) + node2.meta["debug_handle"] = 3 + debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) + expected_result = { + ( + 1, + 2, + ): "op1", + (3,): "op2", + } + self.assertEqual(debug_handle_to_op_name, expected_result) + + def test_get_aot_debug_handle_to_op_name_mapping_no_debug_handles(self): + # Create a simple graph module with no nodes + graph_module = torch.fx.GraphModule({}, torch.fx.Graph()) + debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(graph_module) + expected_result = {} + self.assertEqual(debug_handle_to_op_name, expected_result) + + def test_node_filter_match(self): + node_filter = NodeFilter( + "debug_handle", "call_function", exclude_ops=["getitem"] + ) + + # Create a mock node that matches the filter criteria + mock_node = torch.fx.Node( + graph=torch.fx.Graph(), + name="mock_node", + op="call_function", + target=torch.nn.functional.relu, + args=(), + kwargs={}, + ) + mock_node.meta["debug_handle"] = (1, 2) + # Test that the filter matches the mock node + self.assertTrue(node_filter.matches(mock_node)) + + def test_node_filter_key_mismatch(self): + node_filter = NodeFilter( + "debug_handle", "call_function", exclude_ops=["getitem"] + ) + mock_node_metadata_key_mismatch = torch.fx.Node( + graph=torch.fx.Graph(), + name="mock_node_metadata_key_mismatch", + op="call_function", + target=torch.nn.functional.relu, + args=(), + kwargs={}, + ) + # Test that the filter doesn't match the mock node (meta doesn't have debug_handle key) + self.assertFalse(node_filter.matches(mock_node_metadata_key_mismatch)) + + def test_node_filter_ops_mismatch(self): + node_filter = NodeFilter( + "debug_handle", "call_function", exclude_ops=["getitem"] + ) + + mock_node_exclude_ops_mismatch = torch.fx.Node( + graph=torch.fx.Graph(), + name="getitem", + op="call_function", + target=torch.nn.functional.relu, + args=(), + kwargs={}, + ) + mock_node_exclude_ops_mismatch.meta["debug_handle"] = (1, 2) + # Test that the filter doesn't match the mock node (exclude_ops mismatch) + self.assertFalse(node_filter.matches(mock_node_exclude_ops_mismatch)) + + def test_node_op_type_mismatch(self): + node_filter = NodeFilter( + "debug_handle", "call_function", exclude_ops=["getitem"] + ) + + mock_node_op_type_mismatch = torch.fx.Node( + graph=torch.fx.Graph(), + name="mock_node_op_type_mismatch", + op="get_attr", + target="torch.nn.functional.relu", + args=(), + kwargs={}, + ) + mock_node_op_type_mismatch.meta["debug_handle"] = (1, 2) + # Test that the filter doesn't match the mock node (op_type mismatch) + self.assertFalse(node_filter.matches(mock_node_op_type_mismatch)) + def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]]