From 3fd434bc8c530c7355efbeed18b506e02c38f0e0 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 8 Aug 2025 00:22:12 -0700 Subject: [PATCH] raise error when trying to save an etrecord missing essential info Pull Request resolved: https://github.com/pytorch/executorch/pull/13143 as title ghstack-source-id: 301624419 @exported-using-ghexport Differential Revision: [D79687142](https://our.internmc.facebook.com/intern/diff/D79687142/) --- devtools/etrecord/_etrecord.py | 28 ++++++++++++++++++++++- devtools/etrecord/tests/etrecord_test.py | 29 ++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index 6c8a55d6220..3906dcb1030 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -70,6 +70,22 @@ def __init__( _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None, _representative_inputs: Optional[List[ProgramInput]] = None, ): + """ + Please do not construct an ETRecord object directly. + + If you want to create an ETRecord for logging AOT information to further analysis, please mark `generate_etrecord` + as True in your export api, and get the ETRecord object from the `ExecutorchProgramManager`. + For exmaple: + ```python + exported_program = torch.export.export(model, inputs) + edge_program = to_edge_transform_and_lower(exported_program, generate_etrecord=True) + executorch_program = edge_program.to_executorch() + etrecord = executorch_program.get_etrecord() + ``` + + If user need to create an ETRecord manually, please use the `create_etrecord` function. + """ + self.exported_program = exported_program self.export_graph_id = export_graph_id self.edge_dialect_program = edge_dialect_program @@ -81,15 +97,25 @@ def __init__( def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None: """ - Serialize and save the ETRecord to the specified path. + Serialize and save the ETRecord to the specified path for use in Inspector. The ETRecord + should contains at least edge dialect program and executorch program information for further + analysis, otherwise it will raise an exception. Args: path: Path where the ETRecord file will be saved to. + + Raises: + RuntimeError: If the ETRecord does not contain essential information for Inpector. """ if isinstance(path, (str, os.PathLike)): # pyre-ignore[6]: In call `os.fspath`, for 1st positional argument, expected `str` but got `Union[PathLike[typing.Any], str]` path = os.fspath(path) + if not (self.edge_dialect_program and self._debug_handle_map): + raise RuntimeError( + "ETRecord must contain edge dialect program and executorch program to be saved" + ) + etrecord_zip = ZipFile(path, "w") try: diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 8ca9bd0c2eb..b3b47c679ee 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -1462,3 +1462,32 @@ def test_update_apis_and_save_parse(self): custom_outputs["forward"], parsed_etrecord._reference_outputs["forward"] ): self.assertTrue(torch.equal(expected[0], actual[0])) + + def test_save_missing_essential_info(self): + def expected_runtime_error(etrecord, etrecord_path): + with self.assertRaises(RuntimeError) as context: + etrecord.save(etrecord_path) + + self.assertIn( + "ETRecord must contain edge dialect program and executorch program to be saved", + str(context.exception), + ) + + """Test that save raises RuntimeError when essential info is missing.""" + _, edge_output, et_output = self.get_test_model() + + etrecord = ETRecord() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_no_edge.bin" + + expected_runtime_error(etrecord, etrecord_path) + etrecord.add_edge_dialect_program(edge_output) + + # Should raise runtime error due to missing executorch program related info + expected_runtime_error(etrecord, etrecord_path) + + etrecord.add_executorch_program(et_output) + + # All essential components are now present, so save should succeed + etrecord.save(etrecord_path)