Skip to content

Commit e5131ef

Browse files
committed
raise error when trying to save an etrecord missing essential info
as title Differential Revision: [D79687142](https://our.internmc.facebook.com/intern/diff/D79687142/) [ghstack-poisoned]
1 parent 3652afd commit e5131ef

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,19 @@ def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None:
8585
8686
Args:
8787
path: Path where the ETRecord file will be saved to.
88+
89+
Raises:
90+
RuntimeError: If the ETRecord does not contain essential information for Inpector.
8891
"""
8992
if isinstance(path, (str, os.PathLike)):
9093
# pyre-ignore[6]: In call `os.fspath`, for 1st positional argument, expected `str` but got `Union[PathLike[typing.Any], str]`
9194
path = os.fspath(path)
9295

96+
if not (self.edge_dialect_program and self.graph_map and self._debug_handle_map):
97+
raise RuntimeError(
98+
"ETRecord must contain edge dialect program, graph map, and debug handle map to be saved."
99+
)
100+
93101
etrecord_zip = ZipFile(path, "w")
94102

95103
try:

devtools/etrecord/tests/etrecord_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,3 +1545,32 @@ def test_update_apis_and_save_parse(self):
15451545
custom_outputs["forward"], parsed_etrecord._reference_outputs["forward"]
15461546
):
15471547
self.assertTrue(torch.equal(expected[0], actual[0]))
1548+
1549+
def test_save_missing_essential_info(self):
1550+
def expected_runtime_error(etrecord, etrecord_path):
1551+
with self.assertRaises(RuntimeError) as context:
1552+
etrecord.save(etrecord_path)
1553+
1554+
self.assertIn(
1555+
"ETRecord must contain edge dialect program, graph map, and debug handle map to be saved",
1556+
str(context.exception),
1557+
)
1558+
1559+
"""Test that save raises RuntimeError when essential info is missing."""
1560+
_, edge_output, et_output = self.get_test_model()
1561+
1562+
etrecord = ETRecord()
1563+
1564+
with tempfile.TemporaryDirectory() as tmpdirname:
1565+
etrecord_path = tmpdirname + "/etrecord_no_edge.bin"
1566+
1567+
expected_runtime_error(etrecord, etrecord_path)
1568+
etrecord.add_edge_dialect_program(edge_output)
1569+
1570+
# Should raise runtime error due to missing executorch program related info
1571+
expected_runtime_error(etrecord, etrecord_path)
1572+
1573+
etrecord.add_executorch_program(et_output)
1574+
1575+
# All essential components are now present, so save should succeed
1576+
etrecord.save(etrecord_path)

0 commit comments

Comments
 (0)