diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index de7cf93990a..d5ad81fe255 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -29,6 +29,7 @@ from executorch.exir.serde.export_serialize import SerializedArtifact from executorch.exir.serde.serialize import deserialize, serialize +ProgramInput = List[Value] ProgramOutput = List[Value] try: @@ -49,6 +50,7 @@ class ETRecordReservedFileNames(StrEnum): DEBUG_HANDLE_MAP_NAME = "debug_handle_map" DELEGATE_MAP_NAME = "delegate_map" REFERENCE_OUTPUTS = "reference_outputs" + REPRESENTATIVE_INPUTS = "representative_inputs" @dataclass @@ -60,6 +62,7 @@ class ETRecord: Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]] ] = None _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None + _representative_inputs: Optional[List[ProgramOutput]] = None def _handle_exported_program( @@ -157,6 +160,24 @@ def _get_reference_outputs( return reference_outputs +def _get_representative_inputs( + bundled_program: BundledProgram, +) -> List[ProgramInput]: + """ + Extracts out the inputs from the bundled program, keyed by the method names. + """ + for method_test_suite in bundled_program.method_test_suites: + if method_test_suite.method_name == "forward": + if not method_test_suite.test_cases: + raise ValueError( + "The 'forward' method is defined, but no corresponding input test cases are provided." + ) + # Get first example input from the forward method + test_case = method_test_suite.test_cases[0] + return test_case.inputs + raise ValueError("No 'forward' method found in the bundled program.") + + def generate_etrecord( et_record: Union[str, os.PathLike, BinaryIO, IO[bytes]], edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram], @@ -244,6 +265,13 @@ def generate_etrecord( # @lint-ignore PYTHONPICKLEISBAD pickle.dumps(reference_outputs), ) + + representative_inputs = _get_representative_inputs(executorch_program) + etrecord_zip.writestr( + ETRecordReservedFileNames.REPRESENTATIVE_INPUTS, + # @lint-ignore PYTHONPICKLEISBAD + pickle.dumps(representative_inputs), + ) executorch_program = executorch_program.executorch_program etrecord_zip.writestr( @@ -290,6 +318,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 delegate_map = None edge_dialect_program = None reference_outputs = None + representative_inputs = None serialized_exported_program_files = set() serialized_state_dict_files = set() @@ -321,6 +350,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 reference_outputs = pickle.loads( etrecord_zip.read(ETRecordReservedFileNames.REFERENCE_OUTPUTS) ) + elif entry == ETRecordReservedFileNames.REPRESENTATIVE_INPUTS: + # @lint-ignore PYTHONPICKLEISBAD + representative_inputs = pickle.loads( + etrecord_zip.read(ETRecordReservedFileNames.REPRESENTATIVE_INPUTS) + ) else: if entry.endswith("state_dict"): serialized_state_dict_files.add(entry) @@ -352,4 +386,5 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 _debug_handle_map=debug_handle_map, _delegate_map=delegate_map, _reference_outputs=reference_outputs, + _representative_inputs=representative_inputs, ) diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index cf50662c2a1..dd1d40e0292 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -19,6 +19,7 @@ from executorch.devtools.etrecord import generate_etrecord, parse_etrecord from executorch.devtools.etrecord._etrecord import ( _get_reference_outputs, + _get_representative_inputs, ETRecordReservedFileNames, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge @@ -135,15 +136,25 @@ def test_etrecord_generation_with_bundled_program(self): ) etrecord = parse_etrecord(tmpdirname + "/etrecord.bin") - expected = etrecord._reference_outputs - actual = _get_reference_outputs(bundled_program) + expected_inputs = etrecord._representative_inputs + actual_inputs = _get_representative_inputs(bundled_program) # assertEqual() gives "RuntimeError: Boolean value of Tensor with more than one value is ambiguous" when comparing tensors, # so we use torch.equal() to compare the tensors one by one. + for expected, actual in zip(expected_inputs, actual_inputs): + self.assertTrue(torch.equal(expected[0], actual[0])) + self.assertTrue(torch.equal(expected[1], actual[1])) + + expected_outputs = etrecord._reference_outputs + actual_outputs = _get_reference_outputs(bundled_program) self.assertTrue( - torch.equal(expected["forward"][0][0], actual["forward"][0][0]) + torch.equal( + expected_outputs["forward"][0][0], actual_outputs["forward"][0][0] + ) ) self.assertTrue( - torch.equal(expected["forward"][1][0], actual["forward"][1][0]) + torch.equal( + expected_outputs["forward"][1][0], actual_outputs["forward"][1][0] + ) ) def test_etrecord_generation_with_manager(self):