From 995ffcae711f63b5bd30bff6447963ffdd0a6184 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 4 Aug 2025 22:00:41 -0700 Subject: [PATCH] make etrecord set representive IO Pull Request resolved: https://github.com/pytorch/executorch/pull/13052 representive input and reference output in etrecord will not be set during export flow. To continue supporting the two functionalities, this diff creates two class methods to customize IO in etrecord. ghstack-source-id: 300740656 @exported-using-ghexport Differential Revision: [D79386896](https://our.internmc.facebook.com/intern/diff/D79386896/) --- devtools/etrecord/_etrecord.py | 52 +++- devtools/etrecord/tests/etrecord_test.py | 323 ++++++++++++++++++++++- 2 files changed, 370 insertions(+), 5 deletions(-) diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index 3b8a71279fd..6c8a55d6220 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -68,7 +68,7 @@ def __init__( Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]] ] = None, _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None, - _representative_inputs: Optional[List[ProgramOutput]] = None, + _representative_inputs: Optional[List[ProgramInput]] = None, ): self.exported_program = exported_program self.export_graph_id = export_graph_id @@ -345,6 +345,56 @@ def add_edge_dialect_program( # Set the extracted data self.edge_dialect_program = processed_edge_dialect_program + def update_representative_inputs( + self, + representative_inputs: Union[List[ProgramInput], BundledProgram], + ) -> None: + """ + Update the representative inputs in the ETRecord. + + This method allows users to customize the representative inputs that will be + included when the ETRecord is saved. The representative inputs can be provided + directly as a list or extracted from a BundledProgram. + + Args: + representative_inputs: Either a list of ProgramInput objects or a BundledProgram + from which representative inputs will be extracted. + """ + if isinstance(representative_inputs, BundledProgram): + self._representative_inputs = _get_representative_inputs( + representative_inputs + ) + else: + self._representative_inputs = representative_inputs + + def update_reference_outputs( + self, + reference_outputs: Union[ + Dict[str, List[ProgramOutput]], List[ProgramOutput], BundledProgram + ], + ) -> None: + """ + Update the reference outputs in the ETRecord. + + This method allows users to customize the reference outputs that will be + included when the ETRecord is saved. The reference outputs can be provided + directly as a dictionary mapping method names to lists of outputs, as a + single list of outputs (which will be treated as {"forward": List[ProgramOutput]}), + or extracted from a BundledProgram. + + Args: + reference_outputs: Either a dictionary mapping method names to lists of + ProgramOutput objects, a single list of ProgramOutput objects (treated + as outputs for the "forward" method), or a BundledProgram from which + reference outputs will be extracted. + """ + if isinstance(reference_outputs, BundledProgram): + self._reference_outputs = _get_reference_outputs(reference_outputs) + elif isinstance(reference_outputs, list): + self._reference_outputs = {"forward": reference_outputs} + else: + self._reference_outputs = reference_outputs + def _get_reference_outputs( bundled_program: BundledProgram, diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 25ea5a25e1f..dbd7fdfb776 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -10,6 +10,7 @@ import json import tempfile import unittest +from typing import List import executorch.exir.tests.models as models import torch @@ -30,6 +31,42 @@ # TODO : T154728484 Add test cases to cover multiple entry points class TestETRecord(unittest.TestCase): + def assert_representative_inputs_equal( + self, + expected_inputs: List, + actual_inputs: List, + msg: str = "Representative inputs do not match", + ) -> None: + """ + Utility function to compare representative inputs. + + This function handles the comparison of representative inputs, which are lists of tuples + containing tensors. It compares each input tuple element by element using torch.equal(). + + Args: + expected_inputs: List of expected input tuples + actual_inputs: List of actual input tuples + msg: Optional message to display on assertion failure + """ + self.assertEqual( + len(expected_inputs), + len(actual_inputs), + f"{msg}: Different number of input sets", + ) + + for i, (expected, actual) in enumerate(zip(expected_inputs, actual_inputs)): + self.assertEqual( + len(expected), + len(actual), + f"{msg}: Input set {i} has different number of tensors", + ) + + for j, (exp_tensor, act_tensor) in enumerate(zip(expected, actual)): + self.assertTrue( + torch.equal(exp_tensor, act_tensor), + f"{msg}: Tensor {j} in input set {i} does not match", + ) + def assert_etrecord_has_no_exported_program(self, etrecord: ETRecord) -> None: """Assert that ETRecord has no exported program data.""" self.assertIsNone(etrecord.exported_program) @@ -73,8 +110,7 @@ def get_test_model(self): captured_output = exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) captured_output_copy = copy.deepcopy(captured_output) edge_output = captured_output.to_edge( - # TODO(gasoon): Remove _use_edge_ops=False once serde is fully migrated to Edge ops - exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + exir.EdgeCompileConfig(_check_ir_validity=False) ) edge_output_copy = copy.deepcopy(edge_output) et_output = edge_output.to_executorch() @@ -99,8 +135,7 @@ def get_test_model_with_bundled_program(self): captured_output = exir.capture(f, inputs[0], exir.CaptureConfig()) captured_output_copy = copy.deepcopy(captured_output) edge_output = captured_output.to_edge( - # TODO(gasoon): Remove _use_edge_ops=False once serde is fully migrated to Edge ops - exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + exir.EdgeCompileConfig(_check_ir_validity=False) ) edge_output_copy = copy.deepcopy(edge_output) et_output = edge_output.to_executorch() @@ -1230,3 +1265,283 @@ def test_add_all_programs_sequentially(self): parsed_etrecord._delegate_map, json.loads(json.dumps(et_output.delegate_map)), ) + + def test_update_representative_inputs_with_list(self): + """Test update_representative_inputs with a list of ProgramInput objects.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no representative inputs + self.assertIsNone(etrecord._representative_inputs) + + # Create custom representative inputs + f = models.BasicSinMax() + custom_inputs = [f.get_random_inputs() for _ in range(3)] + + # Update representative inputs + etrecord.update_representative_inputs(custom_inputs) + + # Verify representative inputs are now set + self.assertIsNotNone(etrecord._representative_inputs) + self.assertEqual(len(etrecord._representative_inputs), 3) + + # Compare the inputs using utility function + self.assert_representative_inputs_equal( + custom_inputs, + etrecord._representative_inputs, + "Custom inputs do not match ETRecord representative inputs", + ) + + def test_update_representative_inputs_with_bundled_program(self): + """Test update_representative_inputs with a BundledProgram.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=bundled_program.executorch_program.debug_handle_map, + _delegate_map=bundled_program.executorch_program.delegate_map, + ) + + # Verify initial state - no representative inputs + self.assertIsNone(etrecord._representative_inputs) + + # Update representative inputs using bundled program + etrecord.update_representative_inputs(bundled_program) + + # Verify representative inputs are now set + self.assertIsNotNone(etrecord._representative_inputs) + + # Compare with expected inputs from bundled program using utility function + expected_inputs = _get_representative_inputs(bundled_program) + self.assert_representative_inputs_equal( + expected_inputs, + etrecord._representative_inputs, + "Bundled program inputs do not match ETRecord representative inputs", + ) + + def test_update_representative_inputs_overwrite_existing(self): + """Test that update_representative_inputs overwrites existing inputs.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Create an ETRecord instance with existing representative inputs + initial_inputs = _get_representative_inputs(bundled_program) + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=bundled_program.executorch_program.debug_handle_map, + _delegate_map=bundled_program.executorch_program.delegate_map, + _representative_inputs=initial_inputs, + ) + + # Verify initial inputs are set + self.assertIsNotNone(etrecord._representative_inputs) + + # Create new custom inputs + f = models.BasicSinMax() + new_inputs = [f.get_random_inputs() for _ in range(2)] + + # Update representative inputs with new inputs + etrecord.update_representative_inputs(new_inputs) + + # Verify inputs are updated using utility function + self.assertEqual(len(etrecord._representative_inputs), 2) + self.assert_representative_inputs_equal( + new_inputs, + etrecord._representative_inputs, + "New inputs do not match ETRecord representative inputs after overwrite", + ) + + def test_update_reference_outputs_with_dict(self): + """Test update_reference_outputs with a dictionary of outputs.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no reference outputs + self.assertIsNone(etrecord._reference_outputs) + + # Create custom reference outputs + f = models.BasicSinMax() + inputs = [f.get_random_inputs() for _ in range(2)] + custom_outputs = { + "forward": [f.forward(*inp) for inp in inputs], + "custom_method": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])], + } + + # Update reference outputs + etrecord.update_reference_outputs(custom_outputs) + + # Verify reference outputs are now set + self.assertIsNotNone(etrecord._reference_outputs) + self.assertIn("forward", etrecord._reference_outputs) + self.assertIn("custom_method", etrecord._reference_outputs) + + # Compare the outputs + self.assertEqual(len(etrecord._reference_outputs["forward"]), 2) + self.assertEqual(len(etrecord._reference_outputs["custom_method"]), 2) + + for expected, actual in zip( + custom_outputs["forward"], etrecord._reference_outputs["forward"] + ): + self.assertTrue(torch.equal(expected[0], actual[0])) + + for expected, actual in zip( + custom_outputs["custom_method"], + etrecord._reference_outputs["custom_method"], + ): + self.assertTrue(torch.equal(expected, actual)) + + def test_update_reference_outputs_with_list(self): + """Test update_reference_outputs with a single list of outputs.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no reference outputs + self.assertIsNone(etrecord._reference_outputs) + + # Create custom reference outputs as a single list + f = models.BasicSinMax() + inputs = [f.get_random_inputs() for _ in range(2)] + custom_outputs_list = [f.forward(*inp) for inp in inputs] + + # Update reference outputs with a single list + etrecord.update_reference_outputs(custom_outputs_list) + + # Verify reference outputs are now set and treated as "forward" method + self.assertIsNotNone(etrecord._reference_outputs) + self.assertIn("forward", etrecord._reference_outputs) + self.assertEqual(len(etrecord._reference_outputs["forward"]), 2) + + # Compare the outputs + for expected, actual in zip( + custom_outputs_list, etrecord._reference_outputs["forward"] + ): + self.assertTrue(torch.equal(expected[0], actual[0])) + + def test_update_reference_outputs_with_bundled_program(self): + """Test update_reference_outputs with a BundledProgram.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=bundled_program.executorch_program.debug_handle_map, + _delegate_map=bundled_program.executorch_program.delegate_map, + ) + + # Verify initial state - no reference outputs + self.assertIsNone(etrecord._reference_outputs) + + # Update reference outputs using bundled program + etrecord.update_reference_outputs(bundled_program) + + # Verify reference outputs are now set + self.assertIsNotNone(etrecord._reference_outputs) + self.assertIn("forward", etrecord._reference_outputs) + + # Compare with expected outputs from bundled program + expected_outputs = _get_reference_outputs(bundled_program) + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][0][0], + expected_outputs["forward"][0][0], + ) + ) + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][1][0], + expected_outputs["forward"][1][0], + ) + ) + + def test_update_apis_and_save_parse(self): + """Test that ETRecord with updated inputs/outputs can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Create custom inputs and outputs + f = models.BasicSinMax() + custom_inputs = [f.get_random_inputs() for _ in range(2)] + custom_outputs = { + "forward": [f.forward(*inp) for inp in custom_inputs], + } + + # Update both inputs and outputs + etrecord.update_representative_inputs(custom_inputs) + etrecord.update_reference_outputs(custom_outputs) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_custom_data.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Verify representative inputs are preserved using utility function + self.assertIsNotNone(parsed_etrecord._representative_inputs) + self.assertEqual(len(parsed_etrecord._representative_inputs), 2) + self.assert_representative_inputs_equal( + custom_inputs, + parsed_etrecord._representative_inputs, + "Custom inputs do not match parsed ETRecord representative inputs", + ) + + # Verify reference outputs are preserved + self.assertIsNotNone(parsed_etrecord._reference_outputs) + self.assertIn("forward", parsed_etrecord._reference_outputs) + self.assertEqual(len(parsed_etrecord._reference_outputs["forward"]), 2) + for expected, actual in zip( + custom_outputs["forward"], parsed_etrecord._reference_outputs["forward"] + ): + self.assertTrue(torch.equal(expected[0], actual[0]))