Skip to content

Commit 9caa80d

Browse files
committed
make to_edge_transform_and_lower support etrecord generation
Pull Request resolved: #13034 This diff focus on making `to_edge_transform_and_lower` pipeline support etrecord generation, which can not work on etrecord generation at all previously. Also add tests into it. ghstack-source-id: 300161201 @exported-using-ghexport Differential Revision: [D79336982](https://our.internmc.facebook.com/intern/diff/D79336982/)
1 parent f4fe072 commit 9caa80d

File tree

2 files changed

+213
-2
lines changed

2 files changed

+213
-2
lines changed

devtools/etrecord/tests/etrecord_test.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ETRecordReservedFileNames,
2525
)
2626
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
27+
from executorch.exir.program._program import to_edge_transform_and_lower
2728
from torch.export import export
2829

2930

@@ -52,6 +53,21 @@ def assert_etrecord_is_empty(self, etrecord: ETRecord) -> None:
5253
self.assert_etrecord_has_no_executorch_program(etrecord)
5354
self.assertIsNone(etrecord.graph_map)
5455

56+
def assert_legal_etrecord_in_edge_program(self, etrecord: ETRecord) -> None:
57+
"""Assert that ETRecord has all expected data after to_edge_transform_and_lower() or to_edge() stage"""
58+
self.assertIsNotNone(etrecord.exported_program)
59+
self.assertIsNotNone(etrecord.export_graph_id)
60+
self.assertIsNotNone(etrecord.edge_dialect_program)
61+
self.assert_etrecord_has_no_executorch_program(etrecord)
62+
63+
def assert_etrecord_saveable(self, etrecord: ETRecord) -> None:
64+
"""Assert ETRecord contains all essential information for saving"""
65+
self.assertIsNotNone(etrecord.exported_program)
66+
self.assertIsNotNone(etrecord.export_graph_id)
67+
self.assertIsNotNone(etrecord.edge_dialect_program)
68+
self.assertIsNotNone(etrecord._debug_handle_map)
69+
self.assertIsNotNone(etrecord._delegate_map)
70+
5571
def get_test_model(self):
5672
f = models.BasicSinMax()
5773
captured_output = exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
@@ -275,6 +291,157 @@ def test_etrecord_generation_with_exported_program(self):
275291
# Validate that export_graph_id matches the expected value
276292
self.assertEqual(etrecord.export_graph_id, expected_graph_id)
277293

294+
def test_to_edge_transform_and_lower_with_etrecord_generation(self):
295+
"""Test that to_edge_transform_and_lower generates ETRecord correctly."""
296+
f = models.BasicSinMax()
297+
aten_program = export(f, f.get_random_inputs(), strict=True)
298+
299+
# Test with generate_etrecord=True
300+
edge_manager = to_edge_transform_and_lower(
301+
aten_program,
302+
generate_etrecord=True,
303+
)
304+
305+
# Verify that ETRecord was generated and attached
306+
self.assertIsNotNone(edge_manager._etrecord)
307+
etrecord = edge_manager._etrecord
308+
self.assert_legal_etrecord_in_edge_program(etrecord)
309+
310+
# Verify the exported program matches the input
311+
self.check_graph_closeness(
312+
etrecord.exported_program,
313+
aten_program.graph_module,
314+
)
315+
self.assertEqual(
316+
etrecord.export_graph_id,
317+
id(aten_program.graph),
318+
)
319+
320+
# Verify the edge dialect program matches the edge manager
321+
self.check_graph_closeness(
322+
etrecord.edge_dialect_program,
323+
edge_manager.exported_program().graph_module,
324+
)
325+
326+
def test_to_edge_transform_and_lower_without_etrecord_generation(self):
327+
"""Test that to_edge_transform_and_lower works correctly without ETRecord generation."""
328+
f = models.BasicSinMax()
329+
aten_program = export(f, f.get_random_inputs(), strict=True)
330+
331+
# Test with generate_etrecord=False (default)
332+
edge_manager = to_edge_transform_and_lower(aten_program)
333+
334+
# Verify that no ETRecord was generated
335+
self.assertIsNone(edge_manager._etrecord)
336+
337+
# Verify that the edge manager still works correctly
338+
self.assertIsNotNone(edge_manager.exported_program())
339+
340+
def test_get_etrecord_from_executorch_program_manager(self):
341+
"""Test getting ETRecord from ExecutorchProgramManager using get_etrecord() method."""
342+
f = models.BasicSinMax()
343+
aten_program = export(f, f.get_random_inputs(), strict=True)
344+
345+
# Generate edge manager with ETRecord
346+
edge_manager = to_edge_transform_and_lower(
347+
aten_program,
348+
generate_etrecord=True,
349+
)
350+
351+
# Convert to executorch
352+
et_manager = edge_manager.to_executorch()
353+
354+
# Test get_etrecord method
355+
etrecord = et_manager.get_etrecord()
356+
self.assertIsNotNone(etrecord)
357+
self.assert_etrecord_saveable(etrecord)
358+
359+
# Verify the data matches the original input
360+
self.check_graph_closeness(
361+
etrecord.exported_program,
362+
aten_program.graph_module,
363+
)
364+
self.assertEqual(
365+
etrecord.export_graph_id,
366+
id(aten_program.graph),
367+
)
368+
369+
# Verify the executorch program data matches
370+
# ETRecord stores data directly (not JSON serialized), so compare with original data
371+
self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map)
372+
self.assertEqual(etrecord._delegate_map, et_manager.delegate_map)
373+
374+
def test_get_etrecord_from_executorch_program_manager_without_generation(self):
375+
"""Test getting ETRecord from ExecutorchProgramManager when ETRecord was not generated."""
376+
f = models.BasicSinMax()
377+
aten_program = export(f, f.get_random_inputs(), strict=True)
378+
379+
# Generate edge manager without ETRecord
380+
edge_manager = to_edge_transform_and_lower(aten_program)
381+
382+
# Verify no ETRecord on edge manager
383+
self.assertIsNone(edge_manager._etrecord)
384+
385+
# Convert to executorch
386+
et_manager = edge_manager.to_executorch()
387+
388+
# Verify no ETRecord on executorch manager
389+
self.assertIsNone(et_manager._etrecord)
390+
391+
# Test get_etrecord method should raise RuntimeError
392+
with self.assertRaises(RuntimeError) as context:
393+
et_manager.get_etrecord()
394+
395+
self.assertIn("ETRecord was not generated", str(context.exception))
396+
397+
def test_to_edge_transform_and_lower_etrecord_save_and_parse(self):
398+
"""Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed."""
399+
f = models.BasicSinMax()
400+
aten_program = export(f, f.get_random_inputs(), strict=True)
401+
402+
# Generate edge manager with ETRecord
403+
edge_manager = to_edge_transform_and_lower(
404+
aten_program,
405+
generate_etrecord=True,
406+
)
407+
408+
# Convert to executorch to get complete ETRecord
409+
et_manager = edge_manager.to_executorch()
410+
etrecord = et_manager.get_etrecord()
411+
412+
with tempfile.TemporaryDirectory() as tmpdirname:
413+
etrecord_path = tmpdirname + "/etrecord_flow2.bin"
414+
415+
etrecord.save(etrecord_path)
416+
417+
# Parse ETRecord back and verify
418+
parsed_etrecord = parse_etrecord(etrecord_path)
419+
420+
# Validate that all components are preserved
421+
# Note: Skip graph structure comparison due to transformation differences
422+
self.check_graph_closeness(
423+
etrecord.exported_program, parsed_etrecord.exported_program
424+
)
425+
self.check_graph_closeness(
426+
etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program
427+
)
428+
429+
# Validate executorch program data
430+
self.assertEqual(
431+
parsed_etrecord._debug_handle_map,
432+
json.loads(json.dumps(et_manager.debug_handle_map)),
433+
)
434+
self.assertEqual(
435+
parsed_etrecord._delegate_map,
436+
json.loads(json.dumps(et_manager.delegate_map)),
437+
)
438+
439+
# Validate export graph id
440+
self.assertEqual(
441+
parsed_etrecord.export_graph_id,
442+
id(aten_program.graph),
443+
)
444+
278445
def test_add_extra_export_modules(self):
279446
"""Test add_extra_export_modules when ETRecord already has a graph_map."""
280447
captured_output, edge_output, et_output = self.get_test_model()

exir/program/_program.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,15 @@ def _copy_module(new_prog, new_gm):
291291
setattr(new_prog, node.target, t)
292292

293293

294+
def _create_empty_etrecord():
295+
# Import etrecord at runtime to resolve cyclic dependencies (program -> etrecord -> program).
296+
# This also ensures that etrecord-related packages do not affect the export flow.
297+
# @manual
298+
from executorch.devtools.etrecord import ETRecord
299+
300+
return ETRecord()
301+
302+
294303
def lift_constant_tensor_pass(ep):
295304
"""
296305
Takes an ExportedProgram and returns the ExportedProgram modified in-place,
@@ -1103,6 +1112,7 @@ def _gen_edge_manager_for_partitioners(
11031112
aten_programs: Dict[str, ExportedProgram],
11041113
config: EdgeCompileConfig,
11051114
constant_methods: Optional[Dict[str, Any]],
1115+
generate_etrecord: Optional[bool] = False,
11061116
) -> "EdgeProgramManager":
11071117
"""
11081118
Generates EdgeProgramManager for subsequent lowering to the
@@ -1179,6 +1189,13 @@ def _gen_edge_manager_for_partitioners(
11791189
config,
11801190
list(set().union(*ops_set_to_not_decompose_by_program.values())),
11811191
)
1192+
1193+
if generate_etrecord:
1194+
etrecord = _create_empty_etrecord()
1195+
etrecord.add_exported_program(aten_programs)
1196+
etrecord.add_edge_dialect_program(copy.deepcopy(edge_manager))
1197+
edge_manager._etrecord = etrecord
1198+
11821199
return edge_manager
11831200

11841201

@@ -1220,6 +1237,7 @@ def to_edge_transform_and_lower( # noqa: C901
12201237
] = None,
12211238
constant_methods: Optional[Dict[str, Any]] = None,
12221239
compile_config: Optional[EdgeCompileConfig] = None,
1240+
generate_etrecord: bool = False,
12231241
) -> "EdgeProgramManager":
12241242
"""
12251243
:func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of
@@ -1260,6 +1278,8 @@ def to_edge_transform_and_lower( # noqa: C901
12601278
compile_config: An optional argument used to provide greater control over the
12611279
transformation to edge dialect process.
12621280
1281+
generate_etrecord: An optional argument used to generate an etrecord for debugging purposes.
1282+
12631283
Returns:
12641284
EdgeProgramManager
12651285
"""
@@ -1279,7 +1299,7 @@ def to_edge_transform_and_lower( # noqa: C901
12791299
partitioner, aten_programs
12801300
)
12811301
edge_manager = _gen_edge_manager_for_partitioners(
1282-
partitioner, aten_programs, config, constant_methods
1302+
partitioner, aten_programs, config, constant_methods, generate_etrecord
12831303
)
12841304

12851305
if transform_passes is not None:
@@ -1447,6 +1467,8 @@ def __init__(
14471467
program, self._named_data_store
14481468
)
14491469

1470+
self._etrecord = None
1471+
14501472
@property
14511473
def methods(self) -> Set[str]:
14521474
"""
@@ -1643,13 +1665,19 @@ def to_executorch(
16431665
_copy_module(program.graph_module, new_gm)
16441666
execution_programs[name] = program
16451667

1646-
return ExecutorchProgramManager(
1668+
et_pm = ExecutorchProgramManager(
16471669
execution_programs,
16481670
self._config_methods,
16491671
config,
16501672
self._named_data_store.get_named_data_store_output(),
16511673
)
16521674

1675+
if self._etrecord is not None:
1676+
self._etrecord.add_executorch_program(et_pm)
1677+
et_pm._etrecord = self._etrecord
1678+
1679+
return et_pm
1680+
16531681

16541682
class ExecutorchProgramManager:
16551683
"""
@@ -1713,6 +1741,7 @@ def __init__(
17131741
self._named_data,
17141742
)
17151743
self._buffer: Optional[bytes] = None
1744+
self._etrecord = None
17161745

17171746
@property
17181747
def methods(self) -> Set[str]:
@@ -1785,6 +1814,21 @@ def buffer(self) -> bytes:
17851814
self._buffer = bytes(self._pte_data)
17861815
return self._buffer
17871816

1817+
def get_etrecord(self):
1818+
"""
1819+
Get the generated ETRecord if etrecord generation was enabled.
1820+
1821+
Returns:
1822+
ETRecord object if generation was enabled, None otherwise
1823+
1824+
Raises:
1825+
RuntimeError: if ETRecord object was not generated.
1826+
"""
1827+
1828+
if self._etrecord is None:
1829+
raise RuntimeError("ETRecord was not generated")
1830+
return self._etrecord
1831+
17881832
def write_to_file(self, open_file: io.BufferedIOBase) -> None:
17891833
"""
17901834
Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over

0 commit comments

Comments
 (0)