diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index c797208c0c9..ed07531b70b 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -658,7 +658,7 @@ def _populate_debugging_related_fields( def _associate_with_op_graph_nodes( self, - debug_handle_to_op_node_map: Dict[int, OperatorNode], + debug_handle_to_op_node_map: Dict[int, List[OperatorNode]], ) -> None: """ Helper function to populate the stack_traces, module_hierarchy and op_types attributes @@ -676,14 +676,21 @@ def _associate_with_op_graph_nodes( debug_handles = [debug_handles] for handle in debug_handles: - node = debug_handle_to_op_node_map.get(handle) - # Attach node metadata including stack traces, module hierarchy and op_types to this event - if node is not None and (metadata := node.metadata) is not None: - self.stack_traces[node.name] = metadata.get("stack_trace") - self.module_hierarchy[node.name] = metadata.get("nn_module_stack") - if node.op: - # TODO: consider having this as a dict from node.name -> node.op - self.op_types += [node.op] + nodes = debug_handle_to_op_node_map.get(handle, None) + if nodes is None: + continue + + for node in nodes: + # Attach node metadata including stack traces, module hierarchy and op_types to this event + if node is not None and (metadata := node.metadata) is not None: + if node.name not in self.stack_traces: + self.stack_traces[node.name] = metadata.get("stack_trace") + self.module_hierarchy[node.name] = metadata.get( + "nn_module_stack" + ) + if node.op: + # TODO: consider having this as a dict from node.name -> node.op + self.op_types += [node.op] @dataclass diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index eed8f89b1f7..8df804901f7 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -303,14 +303,23 @@ def gen_graphs_from_etrecord( return op_graph_map +# One debug handle should only be associated with 
one node. We are in the middle of migrating debug handle generation +# from graph after to_edge to graph after torch.export, so every debug handle in exported graph may be associated with multiple nodes in to_edge +# graph. After the migration is complete, we should bring the return type as well as the #node check back. +# +# Before migration: returned Dict for 1 debug handle to 1 node in to_edge graph +# During migration: returned Dict for 1 debug handle to multiple nodes in to_edge graph +# After migration: returned Dict for 1 debug handle to 1 node in exported graph +# +# TODO(gasoonjia): recover the return type to Dict[int, OperatorNode], and re-enable the #node check. def create_debug_handle_to_op_node_mapping( op_graph: OperatorGraph, -) -> Dict[int, OperatorNode]: +) -> Dict[int, List[OperatorNode]]: """ Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping from each debug handle to the operator node that contains the debug handle in its metadata. """ - debug_handle_to_op_node_map: Dict[int, OperatorNode] = {} + debug_handle_to_op_node_map: Dict[int, List[OperatorNode]] = {} # Recursively searches through the metadata of nodes def _extract_debug_handles(graph: OperatorGraph): @@ -320,14 +329,13 @@ def _extract_debug_handles(graph: OperatorGraph): if isinstance(element, OperatorNode) and element.metadata is not None: metadata = element.metadata debug_handle = metadata.get("debug_handle") - if debug_handle is not None: - existing_entry = debug_handle_to_op_node_map.get(debug_handle) - if existing_entry is not None: - raise ValueError( - f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. " - "No two op nodes of the same graph should have the same debug handle." 
- ) - debug_handle_to_op_node_map[debug_handle] = element + if debug_handle is None: + continue + + if debug_handle not in debug_handle_to_op_node_map: + debug_handle_to_op_node_map[debug_handle] = [] + + debug_handle_to_op_node_map[debug_handle].append(element) # Start traversing _extract_debug_handles(op_graph) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 17a9101d894..7c294d81571 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -183,7 +183,11 @@ def test_inspector_associate_with_op_graph_nodes_single_debug_handle(self): # Call the method that's under testing and verify event_with_single_debug_handle._associate_with_op_graph_nodes( - {debug_handle: node_0} + { + debug_handle: [ + node_0, + ] + } ) expected_stack_traces = {"node_0": "stack_trace_relu"} @@ -226,7 +230,14 @@ def test_inspector_associate_with_op_graph_nodes_multiple_debug_handles(self): # Call the method that's under testing and verify event_with_multiple_debug_handles._associate_with_op_graph_nodes( - {debug_handles[0]: node_0, debug_handles[1]: node_1} + { + debug_handles[0]: [ + node_0, + ], + debug_handles[1]: [ + node_1, + ], + } ) expected_stack_traces = { diff --git a/devtools/inspector/tests/inspector_test_utils.py b/devtools/inspector/tests/inspector_test_utils.py index ef36bd6a178..f07bbe0035f 100644 --- a/devtools/inspector/tests/inspector_test_utils.py +++ b/devtools/inspector/tests/inspector_test_utils.py @@ -62,25 +62,17 @@ def get_expected_intermediate_outputs(): Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input. 
""" return { - (10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]), - (11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]), - (12,): torch.tensor( - [ - [0.1000, 0.5000], - [0.2000, 0.6000], - [0.3000, 0.7000], - [0.4000, 0.8000], - ] - ), - (13,): torch.tensor([[5.0000, 14.1200]]), - (14,): torch.tensor([[5.5000, 13.6200]]), - (15,): torch.tensor([[5.4000, 13.5200]]), - (16,): torch.tensor([[10.8000, 6.7600]]), - (17,): torch.tensor([3.0000, 1.5000]), - (18,): torch.tensor([[3.6000, 4.5067]]), - (19,): torch.tensor([[3.6000, 4.5067]]), - (20,): torch.tensor([[0.9734, 0.9891]]), - (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], + (1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]), + (2,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]), + (3,): torch.tensor([[5.0000, 14.1200]]), + (4,): torch.tensor([[5.5000, 13.6200]]), + (5,): torch.tensor([[5.4000, 13.5200]]), + (6,): torch.tensor([[10.8000, 6.7600]]), + (7,): torch.tensor([3.0000, 1.5000]), + (8,): torch.tensor([[3.6000, 4.5067]]), + (9,): torch.tensor([[3.6000, 4.5067]]), + (10,): torch.tensor([[0.9734, 0.9891]]), + (11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], } @staticmethod @@ -89,18 +81,17 @@ def get_expected_debug_handle_to_op_name(): Returns the expected debug handle and op name mapping for this model for the given input. 
""" return { - (10,): "aten_convolution_default", - (11,): "aten_view_copy_default", - (12,): "aten_permute_copy_default", - (13,): "aten_addmm_default", - (14,): "aten_add_tensor", - (15,): "aten_sub_tensor", - (16,): "aten_mul_tensor", - (17,): "aten_add_tensor_1", - (18,): "aten_div_tensor", - (19,): "aten_relu_default", - (20,): "aten_sigmoid_default", - (21,): "aten_split_with_sizes_copy_default", + (1,): "aten_convolution_default", + (2,): "aten_view_copy_default", + (3,): "aten_addmm_default", + (4,): "aten_add_tensor", + (5,): "aten_sub_tensor", + (6,): "aten_mul_tensor", + (7,): "aten_add_tensor_1", + (8,): "aten_div_tensor", + (9,): "aten_relu_default", + (10,): "aten_sigmoid_default", + (11,): "aten_split_with_sizes_copy_default", } diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index cda34e462fd..d7707ffa199 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -583,7 +583,9 @@ def gen_mock_operator_graph_with_expected_map() -> ( "nn_module_stack": "module_hierarchy_relu", }, ) - mapping[111] = node_fused_conv_relu + mapping[111] = [ + node_fused_conv_relu, + ] node_sin = OperatorNode( "sin", [node_fused_conv_relu], @@ -594,7 +596,9 @@ def gen_mock_operator_graph_with_expected_map() -> ( "nn_module_stack": "module_hierarchy_sin", }, ) - mapping[222] = node_sin + mapping[222] = [ + node_sin, + ] node_cos = OperatorNode( "cos", [node_sin], @@ -605,7 +609,9 @@ def gen_mock_operator_graph_with_expected_map() -> ( "nn_module_stack": "module_hierarchy_cos", }, ) - mapping[333] = node_cos + mapping[333] = [ + node_cos, + ] node_div = OperatorNode( "div", [node_cos], @@ -616,7 +622,9 @@ def gen_mock_operator_graph_with_expected_map() -> ( "nn_module_stack": "module_hierarchy_div", }, ) - mapping[444] = node_div + mapping[444] = [ + node_div, + ] node_output = ValueNode("output", [node_div]) return ( OperatorGraph( diff --git 
a/exir/backend/test/qnn_backend_demo.py b/exir/backend/test/qnn_backend_demo.py index 795711a0dd0..1823cea79cf 100644 --- a/exir/backend/test/qnn_backend_demo.py +++ b/exir/backend/test/qnn_backend_demo.py @@ -24,7 +24,9 @@ def preprocess( ) -> PreprocessResult: processed_bytes = "imqnncompiled" all_nodes_debug_handle = [ - node.meta["debug_handle"] for node in edge_program.graph.nodes + node.meta["debug_handle"] + for node in edge_program.graph.nodes + if node.op not in ("placeholder", "output") ] return PreprocessResult( processed_bytes=bytes(processed_bytes, encoding="utf8"), diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py index b5a38d875c2..7576fed8eb2 100644 --- a/exir/backend/test/test_backends.py +++ b/exir/backend/test/test_backends.py @@ -194,7 +194,7 @@ def forward(self, x): program=program, delegate=program.execution_plan[0].delegates[0], expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", ) # Check the delegate instruction @@ -414,7 +414,7 @@ def forward(self, x): program=program, delegate=program.execution_plan[0].delegates[0], expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", ) # Check the delegate instruction @@ -1320,7 +1320,7 @@ def forward(self, x): program=program, delegate=program.execution_plan[0].delegates[0], expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", ) # Check the delegate instruction diff --git a/exir/backend/test/test_backends_lifted.py 
b/exir/backend/test/test_backends_lifted.py index be9527b8ccb..b6aea7f8bb3 100644 --- a/exir/backend/test/test_backends_lifted.py +++ b/exir/backend/test/test_backends_lifted.py @@ -227,7 +227,7 @@ def forward(self, x): program=program, delegate=program.execution_plan[0].delegates[0], expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", ) # Check the delegate instruction @@ -437,7 +437,7 @@ def forward(self, x): program=program, delegate=program.execution_plan[0].delegates[0], expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", ) # Check the delegate instruction diff --git a/exir/backend/test/test_debug_handle_map.py b/exir/backend/test/test_debug_handle_map.py index b02d4633382..c6d426cf082 100644 --- a/exir/backend/test/test_debug_handle_map.py +++ b/exir/backend/test/test_debug_handle_map.py @@ -97,7 +97,13 @@ def test_lowered_the_whole_model(self, unlift): all_debug_handles = list(lowered_model.meta["debug_handle_map"].values())[0] self.assertEqual( len(all_debug_handles), - len(lowered_model.original_module.graph.nodes), + len( + [ + node + for node in lowered_model.original_module.graph.nodes + if node.op not in ("placeholder", "output") + ] + ), ) class ComposedModel(torch.nn.Module): @@ -127,5 +133,11 @@ def forward(self, *args): )[0] self.assertEqual( len(all_debug_handles), - len(lowered_node.original_module.graph.nodes), + len( + [ + node + for node in lowered_node.original_module.graph.nodes + if node.op not in ("placeholder", "output") + ] + ), ) diff --git a/exir/backend/test/test_delegate_map_builder.py b/exir/backend/test/test_delegate_map_builder.py index 827cb8cdebc..2c30e4d9531 100644 --- 
a/exir/backend/test/test_delegate_map_builder.py +++ b/exir/backend/test/test_delegate_map_builder.py @@ -29,6 +29,7 @@ def forward(self, x): model = Model() model_inputs = (torch.ones(1, 1),) + program = ( exir.capture(model, model_inputs, exir.CaptureConfig(pt2_mode=True)) .to_edge() @@ -37,7 +38,7 @@ def forward(self, x): # Create nodes for testing mapping # nodes: [arg0_1, alloc, aten_sin_default, alloc_1, aten_cos_default, output] - # debug handles: [0, None, 1, None, 2, 3] + # debug handles: [None, None, 1, None, 2, None] self.nodes = list(program.graph_module.graph.nodes) self.handles = [node.meta.get("debug_handle") for node in self.nodes] @@ -45,30 +46,30 @@ def forward(self, x): def test_basic_generated_identifier(self): delegate_builder = DelegateMappingBuilder(generated_identifiers=True) - expected_mapping = {0: (1, 2, 3, 4)} + expected_mapping = {0: (1, 2)} self.assertEqual( delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes), 0 ) self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping) - expected_mapping = {0: (1, 2, 3, 4), 1: (1,)} + expected_mapping = {0: (1, 2), 1: (1,)} self.assertEqual( - delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes[0]), 1 + delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes[2]), 1 ) self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping) - expected_mapping = {0: (1, 2, 3, 4), 1: (1,), 2: (2,)} + expected_mapping = {0: (1, 2), 1: (1,), 2: (2,)} self.assertEqual( - delegate_builder.insert_delegate_mapping_entry(handles=self.handles[2]), + delegate_builder.insert_delegate_mapping_entry(handles=self.handles[4]), 2, ) self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping) expected_mapping = { - 0: (1, 2, 3, 4), + 0: (1, 2), 1: (1,), 2: (2,), - 3: (1, 2, 3, 4), + 3: (1, 2), } self.assertEqual( delegate_builder.insert_delegate_mapping_entry(handles=self.handles), 3 @@ -114,7 +115,7 @@ def 
test_omitting_identifier_when_not_generated(self): def test_reinsert_delegate_debug_identifier(self): delegate_builder = DelegateMappingBuilder() delegate_builder.insert_delegate_mapping_entry( - nodes=self.nodes[0], identifier="1" + nodes=self.nodes[2], identifier="1" ) self.assertRaises( @@ -130,6 +131,19 @@ def test_reinsert_delegate_debug_identifier(self): ), ) + self.assertRaises( + Exception, + lambda: delegate_builder.insert_delegate_mapping_entry( + nodes=self.nodes[2], identifier="1" + ), + ) + self.assertRaises( + Exception, + lambda: delegate_builder.insert_delegate_mapping_entry( + handles=self.handles[2], identifier="1" + ), + ) + def test_backend_with_delegate_mapping(self) -> None: model, inputs = BackendWithDelegateMappingDemo.get_test_model_and_inputs() edgeir_m = exir.capture(model, inputs, exir.CaptureConfig()).to_edge( @@ -200,7 +214,7 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]): # Entry with a list of nodes iden_1 = next(identifiers) - expected_mapping = {iden_1: (1, 2, 3, 4)} + expected_mapping = {iden_1: (1, 2)} self.assertEqual( delegate_builder_nodes.insert_delegate_mapping_entry( nodes=self.nodes, identifier=iden_1 @@ -222,16 +236,16 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]): # Entry with a single node iden_2 = next(identifiers) - expected_mapping = {iden_1: (1, 2, 3, 4), iden_2: (1,)} + expected_mapping = {iden_1: (1, 2), iden_2: (1,)} self.assertEqual( delegate_builder_nodes.insert_delegate_mapping_entry( - nodes=self.nodes[0], identifier=iden_2 + nodes=self.nodes[2], identifier=iden_2 ), iden_2, ) self.assertEqual( delegate_builder_handles.insert_delegate_mapping_entry( - handles=self.handles[0], identifier=iden_2 + handles=self.handles[2], identifier=iden_2 ), iden_2, ) diff --git a/exir/backend/test/test_partitioner.py b/exir/backend/test/test_partitioner.py index e9320cf415d..d369a914fac 100644 --- a/exir/backend/test/test_partitioner.py +++ 
b/exir/backend/test/test_partitioner.py @@ -166,7 +166,7 @@ def partition( if not is_param(edge_exported_program, node) and not is_buffer( edge_exported_program, node ): - delegation_tag = "tag_" + str(node.meta["debug_handle"]) + delegation_tag = "tag_" + str(node.name) node.meta["delegation_tag"] = delegation_tag partition_tags[delegation_tag] = self.delegation_spec diff --git a/exir/backend/test/test_to_backend_multi_method.py b/exir/backend/test/test_to_backend_multi_method.py index 045de253e0f..606a9db6e7d 100644 --- a/exir/backend/test/test_to_backend_multi_method.py +++ b/exir/backend/test/test_to_backend_multi_method.py @@ -555,13 +555,13 @@ def forward(self, x): program=program, delegate=program.execution_plan[0].delegates[0], expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", ) self.check_backend_delegate( program=program, delegate=program.execution_plan[1].delegates[0], expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", ) # Check that there are two methods diff --git a/exir/passes/debug_handle_generator_pass.py b/exir/passes/debug_handle_generator_pass.py index 7de8676084b..425558664b3 100644 --- a/exir/passes/debug_handle_generator_pass.py +++ b/exir/passes/debug_handle_generator_pass.py @@ -4,31 +4,75 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Dict + from executorch.exir.graph_module import bfs_trace_with_node_process from executorch.exir.pass_base import ExportPass from torch.export import ExportedProgram -from torch.fx import GraphModule +from torch.fx import GraphModule, Node from torch.fx.passes.infra.pass_base import PassResult class DebugHandleGeneratorPass(ExportPass): def call(self, graph_module: GraphModule) -> PassResult: - """Lower a quantized reference model (with reference quantized operator patterns) - to executorch backend, that has a canonical set of quantized operators + """Generate debug handles for each node in the graph module and its submodule except + placeholder and output nodes. The debug handle is generated starting from 1 and + incrementally. The debug handle of a node is the same as the node sharing the same + greatest ancestor node in the export flow. """ - index = 1 + FROM_NODE_KEY = "from_node" + DEBUG_HANDLE_KEY = "debug_handle" + + source_node_id_to_debug_handle: Dict[str, int] = {} + + def _get_greatest_ancestor_node_identifier(node: Node) -> str: + """Get the identifier of the greatest ancestor node of the given node. + + The identifier is the concatenation of the node name and graph id of the + greatest ancestor node, where the graph id is the unique id for every graph + module in the export flow and node name is unique within the same graph module. + """ + + node_source = node.meta[FROM_NODE_KEY] + node_source = node_source[-1] + + while len(node_source.from_node) > 0: + node_source = node_source.from_node[-1] + + return node_source.name + str(node_source.graph_id) + + def _extract_debug_handles_from_node(node: Node) -> None: + """ + Generate a debug handle based on node's oldest ancestor node's name + and graph id, or return None if the node does not need to be traced. 
+ """ + + if node.op == "placeholder" or node.op == "output": + # placeholder and output nodes don't have debug handle + return + + assert ( + FROM_NODE_KEY in node.meta + ), f"Node {node} does not have meta key {FROM_NODE_KEY}" + + greatest_ancestor_node_id = _get_greatest_ancestor_node_identifier(node) + + debug_handle = ( + len(source_node_id_to_debug_handle) + 1 + if greatest_ancestor_node_id not in source_node_id_to_debug_handle + else source_node_id_to_debug_handle[greatest_ancestor_node_id] + ) - def _extract_debug_handles_from_node(node): - nonlocal index - node.meta["debug_handle"] = index - index += 1 + source_node_id_to_debug_handle[greatest_ancestor_node_id] = debug_handle + node.meta[DEBUG_HANDLE_KEY] = debug_handle bfs_trace_with_node_process(graph_module, _extract_debug_handles_from_node) return PassResult(graph_module, True) +# TODO(gasoonjia): generate missing debug handles using `from_node` info def generate_missing_debug_handles(ep: ExportedProgram): """ This pass is used to generate missing debug handles for the graph module and its submodules. 
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index a9dabad6234..656e20e2fb7 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -67,7 +67,7 @@ from executorch.exir.tensor import TensorSpec from executorch.exir.tests.common import register_additional_test_aten_ops from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic -from executorch.exir.tests.models import MLP, Mul +from executorch.exir.tests.models import FeedForwardBlock, MLP, Mul from functorch.experimental import control_flow from torch import nn @@ -883,11 +883,79 @@ def test_debug_handle_generator_pass(self) -> None: .exported_program() .graph_module ) + + # Every node except input and output should have debug handle for node in graph_module.graph.nodes: - self.assertIn("debug_handle", node.meta) + if node.op != "placeholder" and node.op != "output": + self.assertIn("debug_handle", node.meta) ScalarToTensorPass()(graph_module) + + for node in graph_module.graph.nodes: + if node.op != "placeholder" and node.op != "output": + self.assertIn("debug_handle", node.meta) + + def test_debug_handle_generator_pass_generate_same_debug_handle_on_ops_sharing_same_source( + self, + ) -> None: + eager_model = FeedForwardBlock(256, 512) + inputs = (torch.randn(12, 256),) + + graph_module = ( + to_edge(export(eager_model, inputs, strict=True)) + .exported_program() + .graph_module + ) + + same_source_nodes = { + "aten_native_layer_norm_default": ( + "aten_native_layer_norm_default", + "getitem", + ), + "getitem": ("aten_native_layer_norm_default", "getitem"), + "aten_permute_copy_default": ( + "aten_permute_copy_default", + "aten_addmm_default", + ), + "aten_addmm_default": ("aten_permute_copy_default", "aten_addmm_default"), + "aten_native_dropout_default": ("aten_native_dropout_default", "getitem_1"), + "getitem_1": ("aten_native_dropout_default", "getitem_1"), + "aten_relu_default": ("aten_relu_default",), + "aten_permute_copy_default_1": ( + 
"aten_permute_copy_default_1", + "aten_addmm_default_1", + ), + "aten_addmm_default_1": ( + "aten_permute_copy_default_1", + "aten_addmm_default_1", + ), + "aten_native_dropout_default_1": ( + "aten_native_dropout_default_1", + "getitem_2", + ), + "getitem_2": ("aten_native_dropout_default_1", "getitem_2"), + } + + node_name_to_debug_handle = {} + + # Node having same source should have same debug handle for node in graph_module.graph.nodes: - self.assertIn("debug_handle", node.meta) + if node.op != "placeholder" and node.op != "output": + self.assertIn("debug_handle", node.meta) + if node.name in node_name_to_debug_handle: + for node_name_with_same_debug_handle in same_source_nodes[ + node.name + ]: + self.assertEqual( + node_name_to_debug_handle[node_name_with_same_debug_handle], + node.meta["debug_handle"], + ) + else: + for node_name_with_same_debug_handle in same_source_nodes[ + node.name + ]: + node_name_to_debug_handle[node_name_with_same_debug_handle] = ( + node.meta["debug_handle"] + ) def test_generate_missing_debug_handles(self) -> None: eager_model = MLP(2, output_size=4) @@ -895,10 +963,15 @@ def test_generate_missing_debug_handles(self) -> None: ep = to_edge(export(eager_model, inputs, strict=True)).exported_program() - list(ep.graph.nodes)[0].meta.pop("debug_handle") - self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None) + # get the first non-placeholder node + first_non_placeholder_node = [ + n for n in ep.graph.nodes if n.op != "placeholder" + ][0] + + first_non_placeholder_node.meta.pop("debug_handle") + self.assertTrue(first_non_placeholder_node.meta.get("debug_handle") is None) generate_missing_debug_handles(ep) - self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None) + self.assertTrue(first_non_placeholder_node.meta.get("debug_handle") is not None) def test_debug_handle_generator_pass_with_control_flow(self) -> None: def true_nested(y: torch.Tensor) -> torch.Tensor: @@ -952,7 +1025,8 @@ def 
check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None: while queue: current_graph_module = queue.pop(0) for node in current_graph_module.graph.nodes: - self.assertIn("debug_handle", node.meta) + if node.op != "placeholder" and node.op != "output": + self.assertIn("debug_handle", node.meta) control_flow_submodules = [ submodule for _, submodule, _ in get_control_flow_submodules( @@ -963,7 +1037,6 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None: DebugHandleGeneratorPass()(graph_module) check_debug_handle_metadata(graph_module) - generate_missing_debug_handles(ep) # Check debug handle still preserved after ScalarToTensorPass ScalarToTensorPass()(graph_module)