Skip to content

Commit 273bd0a

Browse files
committed
make et.export support etrecord generation
this diff makes et.export etrecord generation supportive. Details can be found in #12925. After this change, all things in #12925 has completed. Differential Revision: [D79741917](https://our.internmc.facebook.com/intern/diff/D79741917/) [ghstack-poisoned]
1 parent 07acaad commit 273bd0a

File tree

4 files changed

+160
-0
lines changed

4 files changed

+160
-0
lines changed

devtools/etrecord/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ python_unittest(
1313
"//executorch/devtools/etrecord:etrecord",
1414
"//executorch/exir:lib",
1515
"//executorch/exir/tests:models",
16+
"//executorch/export:lib",
1617
],
1718
)
1819

devtools/etrecord/tests/etrecord_test.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
)
2727
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
2828
from executorch.exir.program._program import to_edge, to_edge_transform_and_lower
29+
30+
from executorch.export import export as etexport, ExportRecipe
31+
from executorch.export.types import StageType
2932
from torch.export import export
3033

3134

@@ -136,6 +139,33 @@ def get_test_model_with_bundled_program(self):
136139
bundled_program = BundledProgram(et_output, method_test_suites)
137140
return (aten_dialect, edge_program_copy, bundled_program)
138141

142+
def get_test_export_session(self, generate_etrecord=False, to_edge_flow=False):
143+
f = models.BasicSinMax()
144+
example_inputs = [f.get_random_inputs()]
145+
export_recipe = None
146+
147+
if to_edge_flow:
148+
export_recipe = ExportRecipe(
149+
pipeline_stages=[
150+
StageType.TORCH_EXPORT,
151+
StageType.TO_EDGE,
152+
StageType.TO_BACKEND,
153+
StageType.TO_EXECUTORCH,
154+
]
155+
)
156+
else:
157+
export_recipe = ExportRecipe()
158+
159+
# Test with generate_etrecord=True
160+
export_session = etexport(
161+
model=f,
162+
example_inputs=example_inputs,
163+
export_recipe=export_recipe,
164+
generate_etrecord=generate_etrecord,
165+
)
166+
167+
return export_session
168+
139169
# Serialized and deserialized graph modules are not completely the same, so we check
140170
# that they are close enough and match especially on the parameters we care about in the Developer Tools.
141171
def check_graph_closeness(self, graph_a, graph_b):
@@ -1261,6 +1291,113 @@ def test_add_all_programs_sequentially(self):
12611291
json.loads(json.dumps(et_output.delegate_map)),
12621292
)
12631293

1294+
def test_executorch_export_with_etrecord_generation(self):
1295+
"""Test that executorch.export generates ETRecord correctly when generate_etrecord=True."""
1296+
# Verify that ETRecord was generated and can be retrieved
1297+
export_session = self.get_test_export_session(generate_etrecord=True)
1298+
etrecord = export_session.get_etrecord()
1299+
self.assertIsNotNone(etrecord)
1300+
self.assert_etrecord_saveable(etrecord)
1301+
1302+
# Verify the executorch program data matches
1303+
et_manager = export_session.get_executorch_program_manager()
1304+
self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map)
1305+
self.assertEqual(etrecord._delegate_map, et_manager.delegate_map)
1306+
1307+
def test_executorch_export_without_etrecord_generation(self):
1308+
"""Test that executorch.export works correctly without ETRecord generation."""
1309+
# Test with generate_etrecord=False (default)
1310+
export_session = self.get_test_export_session(generate_etrecord=False)
1311+
1312+
# Verify that no ETRecord was generated
1313+
with self.assertRaises(RuntimeError) as context:
1314+
export_session.get_etrecord()
1315+
1316+
self.assertIn("ETRecord was not generated", str(context.exception))
1317+
1318+
# Verify that the export session still works correctly
1319+
self.assertIsNotNone(export_session.get_executorch_program_manager())
1320+
self.assertTrue(len(export_session.get_pte_buffer()) > 0)
1321+
1322+
def test_executorch_export_etrecord_save_and_parse(self):
1323+
"""Test that ETRecord generated by executorch.export can be saved and parsed."""
1324+
export_session = self.get_test_export_session(generate_etrecord=True)
1325+
1326+
etrecord = export_session.get_etrecord()
1327+
1328+
with tempfile.TemporaryDirectory() as tmpdirname:
1329+
etrecord_path = tmpdirname + "/etrecord_export.bin"
1330+
1331+
etrecord.save(etrecord_path)
1332+
1333+
# Parse ETRecord back and verify
1334+
parsed_etrecord = parse_etrecord(etrecord_path)
1335+
1336+
# Validate that all components are preserved
1337+
self.assertIsNotNone(parsed_etrecord.exported_program)
1338+
self.assertIsNotNone(parsed_etrecord.edge_dialect_program)
1339+
1340+
# Validate executorch program data
1341+
et_manager = export_session.get_executorch_program_manager()
1342+
self.assertEqual(
1343+
parsed_etrecord._debug_handle_map,
1344+
json.loads(json.dumps(et_manager.debug_handle_map)),
1345+
)
1346+
self.assertEqual(
1347+
parsed_etrecord._delegate_map,
1348+
json.loads(json.dumps(et_manager.delegate_map)),
1349+
)
1350+
1351+
# Validate export graph id is preserved
1352+
self.assertIsNotNone(parsed_etrecord.export_graph_id)
1353+
1354+
def test_executorch_export_with_to_edge_flow(self):
1355+
"""Test executorch.export with TO_EDGE flow and ETRecord generation."""
1356+
export_session = self.get_test_export_session(
1357+
generate_etrecord=True,
1358+
to_edge_flow=True,
1359+
)
1360+
1361+
# Verify that ETRecord was generated
1362+
etrecord = export_session.get_etrecord()
1363+
self.assertIsNotNone(etrecord)
1364+
self.assert_etrecord_saveable(etrecord)
1365+
1366+
def test_executorch_export_etrecord_with_to_edge_flow_save_and_parse(self):
1367+
"""Test that ETRecord generated by executorch.export can be saved and parsed."""
1368+
export_session = self.get_test_export_session(
1369+
generate_etrecord=True,
1370+
to_edge_flow=True,
1371+
)
1372+
1373+
etrecord = export_session.get_etrecord()
1374+
1375+
with tempfile.TemporaryDirectory() as tmpdirname:
1376+
etrecord_path = tmpdirname + "/etrecord_export.bin"
1377+
1378+
etrecord.save(etrecord_path)
1379+
1380+
# Parse ETRecord back and verify
1381+
parsed_etrecord = parse_etrecord(etrecord_path)
1382+
1383+
# Validate that all components are preserved
1384+
self.assertIsNotNone(parsed_etrecord.exported_program)
1385+
self.assertIsNotNone(parsed_etrecord.edge_dialect_program)
1386+
1387+
# Validate executorch program data
1388+
et_manager = export_session.get_executorch_program_manager()
1389+
self.assertEqual(
1390+
parsed_etrecord._debug_handle_map,
1391+
json.loads(json.dumps(et_manager.debug_handle_map)),
1392+
)
1393+
self.assertEqual(
1394+
parsed_etrecord._delegate_map,
1395+
json.loads(json.dumps(et_manager.delegate_map)),
1396+
)
1397+
1398+
# Validate export graph id is preserved
1399+
self.assertIsNotNone(parsed_etrecord.export_graph_id)
1400+
12641401
def test_update_representative_inputs_with_list(self):
12651402
"""Test update_representative_inputs with a list of ProgramInput objects."""
12661403
captured_output, edge_output, et_output = self.get_test_model()

export/export.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def export(
4444
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
4545
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
4646
artifact_dir: Optional[str] = None,
47+
generate_etrecord: Optional[bool] = False,
4748
) -> "ExportSession":
4849
"""
4950
Create and configure an ExportSession with the given parameters.
@@ -61,6 +62,7 @@ def export(
6162
dynamic_shapes: Optional dynamic shape specifications
6263
constant_methods: Optional dictionary of constant methods
6364
artifact_dir: Optional directory to store artifacts
65+
generate_etrecord: Optional flag to generate an etrecord
6466
6567
Returns:
6668
A configured ExportSession instance with the export process completed if requested
@@ -73,6 +75,7 @@ def export(
7375
dynamic_shapes=dynamic_shapes,
7476
constant_methods=constant_methods,
7577
artifact_dir=artifact_dir,
78+
generate_etrecord=generate_etrecord,
7679
)
7780
session.export()
7881

@@ -104,6 +107,7 @@ def __init__(
104107
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
105108
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
106109
artifact_dir: Optional[str] = None,
110+
generate_etrecord: Optional[bool] = False,
107111
) -> None:
108112
"""
109113
Initialize the ExportSession with model, inputs, and recipe.
@@ -118,6 +122,7 @@ def __init__(
118122
dynamic_shapes: Optional dynamic shape specifications
119123
constant_methods: Optional dictionary of constant methods
120124
artifact_dir: Optional directory to store artifacts
125+
generate_etrecord: Optional flag to generate an etrecord
121126
"""
122127
# Standardize model to dictionary format
123128
self._model = model if isinstance(model, dict) else {"forward": model}
@@ -165,6 +170,7 @@ def __init__(
165170
"export_recipe": self._export_recipe,
166171
"session_name": name,
167172
"artifact_dir": artifact_dir,
173+
"generate_etrecord": generate_etrecord,
168174
}
169175

170176
self._stage_to_artifacts: Dict[StageType, PipelineArtifact] = {}
@@ -453,3 +459,16 @@ def print_delegation_info(self) -> None:
453459
logging.info(tabulate(df, headers="keys", tablefmt="fancy_grid"))
454460
else:
455461
logging.info("No delegation info available")
462+
463+
# Use Any instead of ETRecord as return type to avoid static dependency on etrecord
464+
def get_etrecord(self) -> Any:
465+
"""
466+
Get the etrecord from the ExecuTorchProgramManager.
467+
468+
Returns:
469+
The etrecord in the ExecuTorchProgramManager
470+
471+
Raises:
472+
RuntimeError: If the ExecuTorchManager is unavailable, or etrecord is not available in the ExecuTorchProgramManager
473+
"""
474+
return self.get_executorch_program_manager().get_etrecord()

export/stages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def run(self, artifact: PipelineArtifact) -> None:
199199
"""
200200
exported_programs = artifact.data
201201
constant_methods = artifact.get_context("constant_methods")
202+
generate_etrecord = artifact.get_context("generate_etrecord", False)
202203

203204
with validation_disabled():
204205
edge_program_manager = to_edge_transform_and_lower(
@@ -207,6 +208,7 @@ def run(self, artifact: PipelineArtifact) -> None:
207208
transform_passes=self._transform_passes,
208209
constant_methods=constant_methods,
209210
compile_config=self._compile_config,
211+
generate_etrecord=generate_etrecord,
210212
)
211213

212214
delegation_info = get_delegation_info(
@@ -418,6 +420,7 @@ def run(self, artifact: PipelineArtifact) -> None:
418420
exported_programs,
419421
constant_methods=constant_methods,
420422
compile_config=self._edge_compile_config,
423+
generate_etrecord=artifact.get_context("generate_etrecord", False),
421424
)
422425

423426
self._artifact = artifact.copy_with_new_data(edge_program_manager)

0 commit comments

Comments
 (0)