Skip to content

Commit e93c139

Browse files
authored
make et.export support etrecord generation
Differential Revision: D79741917 Pull Request resolved: #13303
1 parent a84b3c9 commit e93c139

File tree

6 files changed

+161
-0
lines changed

6 files changed

+161
-0
lines changed

devtools/etrecord/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@ python_library(
2222
"//executorch/exir:lib",
2323
"//executorch/exir/tests:models",
2424
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
25+
"//executorch/export:lib",
2526
],
2627
)

devtools/etrecord/tests/etrecord_test.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
)
2828
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
2929
from executorch.exir.program._program import to_edge, to_edge_transform_and_lower
30+
31+
from executorch.export import export as etexport, ExportRecipe, StageType
3032
from torch.export import export
3133

3234

@@ -137,6 +139,33 @@ def get_test_model_with_bundled_program(self):
137139
bundled_program = BundledProgram(et_output, method_test_suites)
138140
return (aten_dialect, edge_program_copy, bundled_program)
139141

142+
def get_test_export_session(self, generate_etrecord=False, to_edge_flow=False):
143+
f = models.BasicSinMax()
144+
example_inputs = [f.get_random_inputs()]
145+
export_recipe = None
146+
147+
if to_edge_flow:
148+
export_recipe = ExportRecipe(
149+
pipeline_stages=[
150+
StageType.TORCH_EXPORT,
151+
StageType.TO_EDGE,
152+
StageType.TO_BACKEND,
153+
StageType.TO_EXECUTORCH,
154+
]
155+
)
156+
else:
157+
export_recipe = ExportRecipe()
158+
159+
# Test with generate_etrecord=True
160+
export_session = etexport(
161+
model=f,
162+
example_inputs=example_inputs,
163+
export_recipe=export_recipe,
164+
generate_etrecord=generate_etrecord,
165+
)
166+
167+
return export_session
168+
140169
# Serialized and deserialized graph modules are not completely the same, so we check
141170
# that they are close enough and match especially on the parameters we care about in the Developer Tools.
142171
def check_graph_closeness(self, graph_a, graph_b):
@@ -1298,6 +1327,113 @@ def test_add_all_programs_sequentially(self):
12981327
json.loads(json.dumps(et_output.delegate_map)),
12991328
)
13001329

1330+
def test_executorch_export_with_etrecord_generation(self):
1331+
"""Test that executorch.export generates ETRecord correctly when generate_etrecord=True."""
1332+
# Verify that ETRecord was generated and can be retrieved
1333+
export_session = self.get_test_export_session(generate_etrecord=True)
1334+
etrecord = export_session.get_etrecord()
1335+
self.assertIsNotNone(etrecord)
1336+
self.assert_etrecord_saveable(etrecord)
1337+
1338+
# Verify the executorch program data matches
1339+
et_manager = export_session.get_executorch_program_manager()
1340+
self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map)
1341+
self.assertEqual(etrecord._delegate_map, et_manager.delegate_map)
1342+
1343+
def test_executorch_export_without_etrecord_generation(self):
1344+
"""Test that executorch.export works correctly without ETRecord generation."""
1345+
# Test with generate_etrecord=False (default)
1346+
export_session = self.get_test_export_session(generate_etrecord=False)
1347+
1348+
# Verify that no ETRecord was generated
1349+
with self.assertRaises(RuntimeError) as context:
1350+
export_session.get_etrecord()
1351+
1352+
self.assertIn("ETRecord was not generated", str(context.exception))
1353+
1354+
# Verify that the export session still works correctly
1355+
self.assertIsNotNone(export_session.get_executorch_program_manager())
1356+
self.assertTrue(len(export_session.get_pte_buffer()) > 0)
1357+
1358+
def test_executorch_export_etrecord_save_and_parse(self):
1359+
"""Test that ETRecord generated by executorch.export can be saved and parsed."""
1360+
export_session = self.get_test_export_session(generate_etrecord=True)
1361+
1362+
etrecord = export_session.get_etrecord()
1363+
1364+
with tempfile.TemporaryDirectory() as tmpdirname:
1365+
etrecord_path = tmpdirname + "/etrecord_export.bin"
1366+
1367+
etrecord.save(etrecord_path)
1368+
1369+
# Parse ETRecord back and verify
1370+
parsed_etrecord = parse_etrecord(etrecord_path)
1371+
1372+
# Validate that all components are preserved
1373+
self.assertIsNotNone(parsed_etrecord.exported_program)
1374+
self.assertIsNotNone(parsed_etrecord.edge_dialect_program)
1375+
1376+
# Validate executorch program data
1377+
et_manager = export_session.get_executorch_program_manager()
1378+
self.assertEqual(
1379+
parsed_etrecord._debug_handle_map,
1380+
json.loads(json.dumps(et_manager.debug_handle_map)),
1381+
)
1382+
self.assertEqual(
1383+
parsed_etrecord._delegate_map,
1384+
json.loads(json.dumps(et_manager.delegate_map)),
1385+
)
1386+
1387+
# Validate export graph id is preserved
1388+
self.assertIsNotNone(parsed_etrecord.export_graph_id)
1389+
1390+
def test_executorch_export_with_to_edge_flow(self):
1391+
"""Test executorch.export with TO_EDGE flow and ETRecord generation."""
1392+
export_session = self.get_test_export_session(
1393+
generate_etrecord=True,
1394+
to_edge_flow=True,
1395+
)
1396+
1397+
# Verify that ETRecord was generated
1398+
etrecord = export_session.get_etrecord()
1399+
self.assertIsNotNone(etrecord)
1400+
self.assert_etrecord_saveable(etrecord)
1401+
1402+
def test_executorch_export_etrecord_with_to_edge_flow_save_and_parse(self):
1403+
"""Test that ETRecord generated by executorch.export can be saved and parsed."""
1404+
export_session = self.get_test_export_session(
1405+
generate_etrecord=True,
1406+
to_edge_flow=True,
1407+
)
1408+
1409+
etrecord = export_session.get_etrecord()
1410+
1411+
with tempfile.TemporaryDirectory() as tmpdirname:
1412+
etrecord_path = tmpdirname + "/etrecord_export.bin"
1413+
1414+
etrecord.save(etrecord_path)
1415+
1416+
# Parse ETRecord back and verify
1417+
parsed_etrecord = parse_etrecord(etrecord_path)
1418+
1419+
# Validate that all components are preserved
1420+
self.assertIsNotNone(parsed_etrecord.exported_program)
1421+
self.assertIsNotNone(parsed_etrecord.edge_dialect_program)
1422+
1423+
# Validate executorch program data
1424+
et_manager = export_session.get_executorch_program_manager()
1425+
self.assertEqual(
1426+
parsed_etrecord._debug_handle_map,
1427+
json.loads(json.dumps(et_manager.debug_handle_map)),
1428+
)
1429+
self.assertEqual(
1430+
parsed_etrecord._delegate_map,
1431+
json.loads(json.dumps(et_manager.delegate_map)),
1432+
)
1433+
1434+
# Validate export graph id is preserved
1435+
self.assertIsNotNone(parsed_etrecord.export_graph_id)
1436+
13011437
def test_update_representative_inputs_with_list(self):
13021438
"""Test update_representative_inputs with a list of ProgramInput objects."""
13031439
captured_output, edge_output, et_output = self.get_test_model()

export/export.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def export(
4444
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
4545
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
4646
artifact_dir: Optional[str] = None,
47+
generate_etrecord: bool = False,
4748
) -> "ExportSession":
4849
"""
4950
Create and configure an ExportSession with the given parameters.
@@ -61,6 +62,7 @@ def export(
6162
dynamic_shapes: Optional dynamic shape specifications
6263
constant_methods: Optional dictionary of constant methods
6364
artifact_dir: Optional directory to store artifacts
65+
generate_etrecord: Optional flag to generate an etrecord
6466
6567
Returns:
6668
A configured ExportSession instance with the export process completed if requested
@@ -73,6 +75,7 @@ def export(
7375
dynamic_shapes=dynamic_shapes,
7476
constant_methods=constant_methods,
7577
artifact_dir=artifact_dir,
78+
generate_etrecord=generate_etrecord,
7679
)
7780
session.export()
7881

@@ -104,6 +107,7 @@ def __init__(
104107
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
105108
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
106109
artifact_dir: Optional[str] = None,
110+
generate_etrecord: Optional[bool] = False,
107111
) -> None:
108112
"""
109113
Initialize the ExportSession with model, inputs, and recipe.
@@ -118,6 +122,7 @@ def __init__(
118122
dynamic_shapes: Optional dynamic shape specifications
119123
constant_methods: Optional dictionary of constant methods
120124
artifact_dir: Optional directory to store artifacts
125+
generate_etrecord: Optional flag to generate an etrecord
121126
"""
122127
# Standardize model to dictionary format
123128
self._model = model if isinstance(model, dict) else {"forward": model}
@@ -165,6 +170,7 @@ def __init__(
165170
"export_recipe": self._export_recipe,
166171
"session_name": name,
167172
"artifact_dir": artifact_dir,
173+
"generate_etrecord": generate_etrecord,
168174
}
169175

170176
self._stage_to_artifacts: Dict[StageType, PipelineArtifact] = {}
@@ -467,3 +473,16 @@ def print_delegation_info(self) -> None:
467473
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
468474
else:
469475
print("No delegation info available")
476+
477+
# Use Any instead of ETRecord as return type to avoid static dependency on etrecord
478+
def get_etrecord(self) -> Any:
479+
"""
480+
Get the etrecord from the ExecuTorchProgramManager.
481+
482+
Returns:
483+
The etrecord in the ExecuTorchProgramManager
484+
485+
Raises:
486+
RuntimeError: If the ExecuTorchManager is unavailable, or etrecord is not available in the ExecuTorchProgramManager
487+
"""
488+
return self.get_executorch_program_manager().get_etrecord()

export/stages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def run(self, artifact: PipelineArtifact) -> None:
199199
"""
200200
exported_programs = artifact.data
201201
constant_methods = artifact.get_context("constant_methods")
202+
generate_etrecord = artifact.get_context("generate_etrecord", False)
202203

203204
with validation_disabled():
204205
edge_program_manager = to_edge_transform_and_lower(
@@ -207,6 +208,7 @@ def run(self, artifact: PipelineArtifact) -> None:
207208
transform_passes=self._transform_passes,
208209
constant_methods=constant_methods,
209210
compile_config=self._compile_config,
211+
generate_etrecord=generate_etrecord,
210212
)
211213

212214
delegation_info = get_delegation_info(
@@ -418,6 +420,7 @@ def run(self, artifact: PipelineArtifact) -> None:
418420
exported_programs,
419421
constant_methods=constant_methods,
420422
compile_config=self._edge_compile_config,
423+
generate_etrecord=artifact.get_context("generate_etrecord", False),
421424
)
422425

423426
self._artifact = artifact.copy_with_new_data(edge_program_manager)

export/tests/test_export_session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def test_context_propagation_through_pipeline(self) -> None:
184184
"export_recipe",
185185
"session_name",
186186
"artifact_dir",
187+
"generate_etrecord",
187188
}
188189
self.assertEqual(set(session._run_context.keys()), expected_context_keys)
189190
self.assertEqual(session._run_context["session_name"], "test_session")

export/tests/test_export_stages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def test_run_success(self, mock_to_edge: Mock) -> None:
307307
self.exported_programs,
308308
constant_methods=None,
309309
compile_config=mock_config,
310+
generate_etrecord=False,
310311
)
311312

312313
# Verify artifacts are set correctly

0 commit comments

Comments
 (0)