Skip to content

Commit 1a02862

Browse files
raise error when trying to save an etrecord missing essential info (#13231)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13143 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/35/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/35/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/35/orig @diff-train-skip-merge Co-authored-by: gasoonjia <[email protected]> Co-authored-by: Gasoonjia <[email protected]>
1 parent 44a776f commit 1a02862

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,22 @@ def __init__(
7070
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None,
7171
_representative_inputs: Optional[List[ProgramInput]] = None,
7272
):
73+
"""
74+
Please do not construct an ETRecord object directly.
75+
76+
If you want to create an ETRecord for logging AOT information to further analysis, please mark `generate_etrecord`
77+
as True in your export api, and get the ETRecord object from the `ExecutorchProgramManager`.
78+
For exmaple:
79+
```python
80+
exported_program = torch.export.export(model, inputs)
81+
edge_program = to_edge_transform_and_lower(exported_program, generate_etrecord=True)
82+
executorch_program = edge_program.to_executorch()
83+
etrecord = executorch_program.get_etrecord()
84+
```
85+
86+
If user need to create an ETRecord manually, please use the `create_etrecord` function.
87+
"""
88+
7389
self.exported_program = exported_program
7490
self.export_graph_id = export_graph_id
7591
self.edge_dialect_program = edge_dialect_program
@@ -81,15 +97,25 @@ def __init__(
8197

8298
def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None:
8399
"""
84-
Serialize and save the ETRecord to the specified path.
100+
Serialize and save the ETRecord to the specified path for use in Inspector. The ETRecord
101+
should contains at least edge dialect program and executorch program information for further
102+
analysis, otherwise it will raise an exception.
85103
86104
Args:
87105
path: Path where the ETRecord file will be saved to.
106+
107+
Raises:
108+
RuntimeError: If the ETRecord does not contain essential information for Inpector.
88109
"""
89110
if isinstance(path, (str, os.PathLike)):
90111
# pyre-ignore[6]: In call `os.fspath`, for 1st positional argument, expected `str` but got `Union[PathLike[typing.Any], str]`
91112
path = os.fspath(path)
92113

114+
if not (self.edge_dialect_program and self._debug_handle_map):
115+
raise RuntimeError(
116+
"ETRecord must contain edge dialect program and executorch program to be saved"
117+
)
118+
93119
etrecord_zip = ZipFile(path, "w")
94120

95121
try:

devtools/etrecord/tests/etrecord_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,3 +1499,32 @@ def test_update_apis_and_save_parse(self):
14991499
custom_outputs["forward"], parsed_etrecord._reference_outputs["forward"]
15001500
):
15011501
self.assertTrue(torch.equal(expected[0], actual[0]))
1502+
1503+
def test_save_missing_essential_info(self):
1504+
def expected_runtime_error(etrecord, etrecord_path):
1505+
with self.assertRaises(RuntimeError) as context:
1506+
etrecord.save(etrecord_path)
1507+
1508+
self.assertIn(
1509+
"ETRecord must contain edge dialect program and executorch program to be saved",
1510+
str(context.exception),
1511+
)
1512+
1513+
"""Test that save raises RuntimeError when essential info is missing."""
1514+
_, edge_output, et_output = self.get_test_model()
1515+
1516+
etrecord = ETRecord()
1517+
1518+
with tempfile.TemporaryDirectory() as tmpdirname:
1519+
etrecord_path = tmpdirname + "/etrecord_no_edge.bin"
1520+
1521+
expected_runtime_error(etrecord, etrecord_path)
1522+
etrecord.add_edge_dialect_program(edge_output)
1523+
1524+
# Should raise runtime error due to missing executorch program related info
1525+
expected_runtime_error(etrecord, etrecord_path)
1526+
1527+
etrecord.add_executorch_program(et_output)
1528+
1529+
# All essential components are now present, so save should succeed
1530+
etrecord.save(etrecord_path)

0 commit comments

Comments
 (0)