diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 32a46ab0276..d49ce3959a6 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -35,8 +35,17 @@ from executorch.devtools.etdump.serialize import deserialize_from_etdump_flatcc from executorch.devtools.etrecord import ETRecord +from executorch.exir.debug_handle_utils import ( + DEBUG_HANDLE_KEY, + get_greatest_ancestor_node_identifier, +) + +from executorch.exir.graph_module import bfs_trace_with_node_process + from tabulate import tabulate +from torch.export import ExportedProgram + FORWARD = "forward" EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module" @@ -888,3 +897,71 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]: else: # Raise an error if one is a sequence and the other is not raise ValueError("Both inputs must be sequences or both must be non-sequences.") + + +def propagate_back_debug_handle( + exported_program: ExportedProgram, + exported_program_graph_id: int, + edge_dialect_program: ExportedProgram, +) -> bool: + """ + Propagate debug handle from edge dialect program back to the exported program while maintain the correctness + of operator tracing. + + e.g. + export program: op1 -> op2 -> op3 + edge dialect program: op1_0 -> op3_0 -> op3_1 + where op1_0 is from op1, op3_0 and op3_1 are from op3, op2 is removed by to_edge pipeline (e.g. RemoveNoopPass). + + Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1. + The debug handle of op2 will be a non-existing debug handle in edge dialect program for further skipping. + + Return: True if: + a. every debug handle in the edge dialect program has a corresponding node in the exported program + b. the exported program is the greatest ancestor of the edge dialect program + + Otherwise, return False. + """ + + # 1. set up a mapping from debug handle to identifier of export program's node + # using edge dialect program nodes' debug handles and from_node info + export_graph_node_id_to_debug_handle = { + get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY] + for node in edge_dialect_program.graph.nodes + if node.op not in ("placeholder", "output") + } + + # 2. equip debug handle to the exported program's nodes using the mapping + # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle + n_matched_node = 0 + + # debug handle for the node in the exported program but not in the edge dialect program + debug_handle_for_removed_node = ( + max(export_graph_node_id_to_debug_handle.values()) + 1 + ) + + def _find_n_match_node(node: torch.fx.Node) -> None: + nonlocal n_matched_node + if node.name in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + if node_id in export_graph_node_id_to_debug_handle: + n_matched_node += 1 + + def _equip_debug_handle(node: torch.fx.Node) -> None: + if node.name in ("output", "placeholder"): + return + node_id = f"{node.name}.{exported_program_graph_id}" + if node_id in export_graph_node_id_to_debug_handle: + node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id] + else: + node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node + + bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node) + + # if any node in the edge dialect program has no corresponding node in the exported program, match failed + if n_matched_node != len(export_graph_node_id_to_debug_handle): + return False + + bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle) + return True diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 47113910e98..a77d541cb06 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -10,8 +10,9 @@ import unittest from typing import Dict, Tuple -import torch +import executorch.exir.tests.models as models +import torch from executorch.devtools import generate_etrecord, parse_etrecord from executorch.devtools.debug_format.base_schema import ( @@ -41,9 +42,13 @@ map_runtime_aot_intermediate_outputs, merge_runtime_overlapping_debug_handles, NodeFilter, + propagate_back_debug_handle, TimeScale, ) from executorch.devtools.inspector.numerical_comparator import L1Comparator +from executorch.exir import to_edge +from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY +from torch.export import export class TestInspectorUtils(unittest.TestCase): @@ -583,6 +588,113 @@ def test_compare_intermediate_outputs_sequence_and_non_sequence(self): with self.assertRaises(ValueError): compare_intermediate_outputs(a, b, L1Comparator()) + def test_equip_debug_handle_to_export_program_success(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + # Export the model + exported_program = export(model, inputs) + export_graph_id = id(exported_program.graph) + + # Convert to edge dialect + edge_dialect_program = to_edge(exported_program).exported_program() + + # Call propagate_back_debug_handle + result = propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + + self.assertTrue(result) + + # Check that debug handles are properly equipped in the exported program + exported_program_debug_handles = [] + for node in exported_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + exported_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + edge_dialect_program_debug_handles = [] + for node in edge_dialect_program.graph.nodes: + if node.op not in ("placeholder", "output"): + self.assertIn(DEBUG_HANDLE_KEY, node.meta) + self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY]) + edge_dialect_program_debug_handles.append(node.meta[DEBUG_HANDLE_KEY]) + + # The 0th operator in the exported program (layer_norm) has been decomposed into 0th and 1st ops in edge dialect graph (native_layer_norm and getitem) + # So they should have the same debug handle + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[0] + ) + self.assertEqual( + exported_program_debug_handles[0], edge_dialect_program_debug_handles[1] + ) + + def test_equip_debug_handle_to_export_program_failure(self): + """Test that propagate_back_debug_handle returns False when there's a mismatch.""" + # Create a test model + model = models.FeedForwardBlock(5, 10) + inputs = (torch.rand(5, 5),) + + exported_program = export(model, inputs) + edge_dialect_program = to_edge(exported_program).exported_program() + + # Create a different exported program (reexport) to cause mismatch + reexported_program = export(model, inputs) + reexport_graph_id = id(reexported_program.graph) + + # Call propagate_back_debug_handle with mismatched programs + # This should return False because the reexported program has different node identifiers + result = propagate_back_debug_handle( + reexported_program, reexport_graph_id, edge_dialect_program + ) + + # Check that it returns False due to mismatch + self.assertFalse(result) + + def test_equip_debug_handle_to_export_program_op_to_be_removed_in_to_edge(self): + """Test that propagate_back_debug_handle returns True and properly equips debug handles when an op is removed in to_edge""" + + class M(torch.nn.Module): + """ + Simple model with ops that will be removed in to_edge + """ + + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + 1 + x = x.to(x.dtype) + x = x + 1 + return x + + inputs = (torch.rand(5, 5),) + exported_program = torch.export.export(M(), inputs) + export_graph_id = id(exported_program.graph) + edge_dialect_program = to_edge(exported_program).exported_program() + + self.assertTrue( + propagate_back_debug_handle( + exported_program, export_graph_id, edge_dialect_program + ) + ) + + # only two add ops in the exported program will keep in edge dialect program, so the debug handles for removed op will be three + debug_handle_for_removed_node = 3 + + for node in exported_program.graph.nodes: + if node.name == "add": + self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 1) + elif node.name == "add_1": + self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2) + elif node.op not in ("placeholder", "output"): + self.assertEqual( + node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node + ) + def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]] diff --git a/exir/TARGETS b/exir/TARGETS index 7916cec29fb..cda57de7f80 100644 --- a/exir/TARGETS +++ b/exir/TARGETS @@ -277,3 +277,11 @@ python_library( "fbsource//third-party/pypi/typing-extensions:typing-extensions", ], ) + +python_library( + name = "debug_handle_utils", + srcs = ["debug_handle_utils.py"], + deps = [ + "//caffe2:torch", + ], +) diff --git a/exir/debug_handle_utils.py b/exir/debug_handle_utils.py new file mode 100644 index 00000000000..d1a70fcd213 --- /dev/null +++ b/exir/debug_handle_utils.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.fx import Node + +FROM_NODE_KEY = "from_node" +DEBUG_HANDLE_KEY = "debug_handle" + + +def get_greatest_ancestor_node_identifier(node: Node) -> str: + """Get the identifier of the greatest ancestor node of the given node. + + The identifier is the concatenation of the node name and graph id of the + greatest ancestor node, where the graph id is the unique id for every graph + module in the export flow and node name is unique within the same graph module. + """ + + node_source = node.meta[FROM_NODE_KEY] + node_source = node_source[-1] + + while len(node_source.from_node) > 0: + node_source = node_source.from_node[-1] + + return f"{node_source.name}.{str(node_source.graph_id)}" diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index 8699fe2fd02..0a1f5117f20 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -342,6 +342,7 @@ python_library( ], deps = [ "//caffe2:torch", + "//executorch/exir:debug_handle_utils", "//executorch/exir:graph_module", "//executorch/exir:pass_base", ], diff --git a/exir/passes/debug_handle_generator_pass.py b/exir/passes/debug_handle_generator_pass.py index 425558664b3..fe705273a51 100644 --- a/exir/passes/debug_handle_generator_pass.py +++ b/exir/passes/debug_handle_generator_pass.py @@ -6,6 +6,11 @@ from typing import Dict +from executorch.exir.debug_handle_utils import ( + DEBUG_HANDLE_KEY, + FROM_NODE_KEY, + get_greatest_ancestor_node_identifier, +) from executorch.exir.graph_module import bfs_trace_with_node_process from executorch.exir.pass_base import ExportPass from torch.export import ExportedProgram @@ -21,27 +26,8 @@ def call(self, graph_module: GraphModule) -> PassResult: greatest ancestor node in the export flow. """ - FROM_NODE_KEY = "from_node" - DEBUG_HANDLE_KEY = "debug_handle" - source_node_id_to_debug_handle: Dict[str, int] = {} - def _get_greatest_ancestor_node_identifier(node: Node) -> str: - """Get the identifier of the greatest ancestor node of the given node. - - The identifier is the concatenation of the node name and graph id of the - greatest ancestor node, where the graph id is the unique id for every graph - module in the export flow and node name is unique within the same graph module. - """ - - node_source = node.meta[FROM_NODE_KEY] - node_source = node_source[-1] - - while len(node_source.from_node) > 0: - node_source = node_source.from_node[-1] - - return node_source.name + str(node_source.graph_id) - def _extract_debug_handles_from_node(node: Node) -> None: """ Generate a debug handle based on node's oldest ancestor node's name @@ -56,7 +42,7 @@ def _extract_debug_handles_from_node(node: Node) -> None: FROM_NODE_KEY in node.meta ), f"Node {node} does not have meta key {FROM_NODE_KEY}" - greatest_ancestor_node_id = _get_greatest_ancestor_node_identifier(node) + greatest_ancestor_node_id = get_greatest_ancestor_node_identifier(node) debug_handle = ( len(source_node_id_to_debug_handle) + 1