From de271fee6de7cbb852eda39ec973021415f86d78 Mon Sep 17 00:00:00 2001 From: Juntian Liu Date: Wed, 28 May 2025 18:39:55 -0700 Subject: [PATCH] Create the `IntermediateOutputCapturer` Class to Store the IntermediateOutput of the AOT Graph (#11202) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/11202 This Diff introduces a new Python class, IntermediateOutputCapturer, which inherits from torch.fx.interpreter.Interpreter. The primary purpose of this class is to capture the output tensor(s) produced by each operator (node) for an EdgeProgramManager's GraphModule. We will use these stored outputs to compare with later runtime operator outputs to detect numerical discrepancies. The IntermediateOutputCapturer class overrides the run_node method to store the computed results in an instance dictionary. It checks for the presence of a debug_handle in the node's metadata and the type of the node and uses it as a key to store the result. Tensors are detached and cloned to prevent side effects, while non-tensor results are stored directly. A public method, run_and_capture, is implemented to call the base Interpreter's run method and return the dictionary containing the captured debug_handle -> output mappings. Additionally, an __init__ method is provided to accept an fx.GraphModule as input and a print_captured_outputs method is included for debugging purposes. Differential Revision: D75492919 --- devtools/inspector/TARGETS | 9 ++ .../_intermediate_output_capturer.py | 50 +++++++ devtools/inspector/tests/TARGETS | 11 ++ .../intermediate_output_capturer_test.py | 134 ++++++++++++++++++ 4 files changed, 204 insertions(+) create mode 100644 devtools/inspector/_intermediate_output_capturer.py create mode 100644 devtools/inspector/tests/intermediate_output_capturer_test.py diff --git a/devtools/inspector/TARGETS b/devtools/inspector/TARGETS index bba5f7f8951..ea6d55f8658 100644 --- a/devtools/inspector/TARGETS +++ b/devtools/inspector/TARGETS @@ -48,6 +48,15 @@ python_library( ], ) +python_library( + name = "intermediate_output_capturer", + srcs = [ + "_intermediate_output_capturer.py", + ], + deps = [ + ], +) + python_library( name = "lib", srcs = ["__init__.py"], diff --git a/devtools/inspector/_intermediate_output_capturer.py b/devtools/inspector/_intermediate_output_capturer.py new file mode 100644 index 00000000000..e3a904487eb --- /dev/null +++ b/devtools/inspector/_intermediate_output_capturer.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + + +from typing import Any, Dict, Tuple + +import torch +from torch.fx import GraphModule +from torch.fx.interpreter import Interpreter + + +class IntermediateOutputCapturer(Interpreter): + def __init__(self, module: GraphModule): + super().__init__(module) + + def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]: + captured_outputs = {} + + def capture_run_node(n: torch.fx.Node) -> Any: + result = super(IntermediateOutputCapturer, self).run_node(n) + debug_handle = n.meta.get("debug_handle", None) + if debug_handle is not None and n.op == "call_function": + # Convert the debug handle to a tuple to use as a dictionary key + key = ( + (debug_handle,) + if isinstance(debug_handle, int) + else tuple(debug_handle) + ) + # Handle tensor results by detaching and cloning + if isinstance(result, torch.Tensor): + captured_outputs[key] = result.detach().clone() + elif isinstance(result, (tuple, list)): + captured_outputs[key] = [ + r.detach().clone() if isinstance(r, torch.Tensor) else r + for r in result + ] + else: + captured_outputs[key] = result + return result + + original_run_node = self.run_node + self.run_node = capture_run_node + self.run(*args, **kwargs) + self.run_node = original_run_node + return captured_outputs diff --git a/devtools/inspector/tests/TARGETS b/devtools/inspector/tests/TARGETS index eada6817bcb..78450dc5fe2 100644 --- a/devtools/inspector/tests/TARGETS +++ b/devtools/inspector/tests/TARGETS @@ -39,3 +39,14 @@ python_unittest( "//executorch/devtools/inspector:inspector_utils", ], ) + +python_unittest( + name = "intermediate_output_capturer_test", + srcs = ["intermediate_output_capturer_test.py"], + deps = [ + "//executorch/devtools/inspector:inspector", + "//executorch/devtools/inspector:lib", + "//executorch/devtools/inspector:intermediate_output_capturer", + "//executorch/exir:lib", + ], +) diff --git a/devtools/inspector/tests/intermediate_output_capturer_test.py b/devtools/inspector/tests/intermediate_output_capturer_test.py new file mode 100644 index 00000000000..e6dd782d887 --- /dev/null +++ b/devtools/inspector/tests/intermediate_output_capturer_test.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + + +import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.devtools.inspector._intermediate_output_capturer import ( + IntermediateOutputCapturer, +) + +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge +from torch.export import export, ExportedProgram +from torch.fx import GraphModule + + +class TestIntermediateOutputCapturer(unittest.TestCase): + @classmethod + def setUpClass(cls): + class TestModule(nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.conv = nn.Conv2d( + in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1 + ) + self.conv.weight = nn.Parameter( + torch.tensor( + [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]] + ) + ) + self.conv.bias = nn.Parameter(torch.tensor([0.0])) + + self.linear = nn.Linear(in_features=4, out_features=2) + self.linear.weight = nn.Parameter( + torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]) + ) + self.linear.bias = nn.Parameter(torch.tensor([0.0, 0.0])) + self.bias = nn.Parameter(torch.tensor([0.5, -0.5]), requires_grad=False) + self.scale = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False) + + def forward(self, x): + x = self.conv(x) + x = x.view(x.size(0), -1) + x = self.linear(x) + x = x + self.bias + x = x - 0.1 + x = x * self.scale + x = x / (self.scale + 1.0) + x = F.relu(x) + x = torch.sigmoid(x) + x1, x2 = torch.split(x, 1, dim=1) + return x1, x2 + + cls.model = TestModule() + cls.input = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True) + cls.aten_model: ExportedProgram = export(cls.model, (cls.input,), strict=True) + cls.edge_program_manager: EdgeProgramManager = to_edge( + cls.aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True) + ) + cls.graph_module: GraphModule = cls.edge_program_manager._edge_programs[ + "forward" + ].module() + cls.capturer = IntermediateOutputCapturer(cls.graph_module) + cls.intermediate_outputs = cls.capturer.run_and_capture(cls.input) + + def test_keying_with_debug_handle_tuple(self): + for key in self.intermediate_outputs.keys(): + self.assertIsInstance(key, tuple) + + def test_tensor_cloning_and_detaching(self): + for output in self.intermediate_outputs.values(): + if isinstance(output, torch.Tensor): + self.assertFalse(output.requires_grad) + self.assertTrue(output.is_leaf) + + def test_placeholder_nodes_are_skipped(self): + for node in self.graph_module.graph.nodes: + if node.op == "placeholder": + self.assertNotIn( + node.meta.get("debug_handle"), self.intermediate_outputs + ) + + def test_multiple_outputs_capture(self): + outputs = self.capturer.run_and_capture(self.input) + for output in outputs.values(): + if isinstance(output, tuple): + self.assertEqual(len(output), 2) + for part in output: + self.assertIsInstance(part, torch.Tensor) + + def test_capture_correct_outputs(self): + expected_outputs_with_handles = { + (10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]), + (11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]), + (12,): torch.tensor( + [[0.1000, 0.5000], [0.2000, 0.6000], [0.3000, 0.7000], [0.4000, 0.8000]] + ), + (13,): torch.tensor([[5.0000, 14.1200]]), + (14,): torch.tensor([[5.5000, 13.6200]]), + (15,): torch.tensor([[5.4000, 13.5200]]), + (16,): torch.tensor([[10.8000, 6.7600]]), + (17,): torch.tensor([3.0000, 1.5000]), + (18,): torch.tensor([[3.6000, 4.5067]]), + (19,): torch.tensor([[3.6000, 4.5067]]), + (20,): torch.tensor([[0.9734, 0.9891]]), + (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], + (22,): torch.tensor([[0.9734]]), + (23,): torch.tensor([[0.9891]]), + } + self.assertEqual( + len(self.intermediate_outputs), len(expected_outputs_with_handles) + ) + + for debug_handle, expected_output in expected_outputs_with_handles.items(): + actual_output = self.intermediate_outputs.get(debug_handle) + self.assertIsNotNone(actual_output) + if isinstance(expected_output, list): + self.assertIsInstance(actual_output, list) + self.assertEqual(len(actual_output), len(expected_output)) + for actual, expected in zip(actual_output, expected_output): + self.assertTrue( + torch.allclose(actual, expected, rtol=1e-4, atol=1e-5) + ) + else: + self.assertTrue( + torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5) + )