Skip to content

Commit 6b37502

Browse files
committed
make to_edge_transform_and_lower support etrecord generation
Differential Revision: [D79336982](https://our.internmc.facebook.com/intern/diff/D79336982/) [ghstack-poisoned]
1 parent 14e38bd commit 6b37502

File tree

2 files changed

+208
-2
lines changed

2 files changed

+208
-2
lines changed

devtools/etrecord/tests/etrecord_test.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ETRecordReservedFileNames,
2525
)
2626
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
27+
from executorch.exir.program._program import to_edge_transform_and_lower
2728
from torch.export import export
2829

2930

@@ -275,6 +276,167 @@ def test_etrecord_generation_with_exported_program(self):
275276
# Validate that export_graph_id matches the expected value
276277
self.assertEqual(etrecord.export_graph_id, expected_graph_id)
277278

279+
def test_to_edge_transform_and_lower_with_etrecord_generation(self):
280+
"""Test that to_edge_transform_and_lower generates ETRecord correctly."""
281+
f = models.BasicSinMax()
282+
aten_program = export(f, f.get_random_inputs(), strict=True)
283+
284+
# Test with generate_etrecord=True
285+
edge_manager = to_edge_transform_and_lower(
286+
aten_program,
287+
generate_etrecord=True,
288+
)
289+
290+
# Verify that ETRecord was generated and attached
291+
self.assertIsNotNone(edge_manager._etrecord)
292+
etrecord = edge_manager._etrecord
293+
294+
# Verify that ETRecord has the expected data
295+
self.assertIsNotNone(etrecord.exported_program)
296+
self.assertIsNotNone(etrecord.export_graph_id)
297+
self.assertIsNotNone(etrecord.edge_dialect_program)
298+
299+
# Verify the exported program matches the input
300+
self.check_graph_closeness(
301+
etrecord.exported_program,
302+
aten_program.graph_module,
303+
)
304+
self.assertEqual(
305+
etrecord.export_graph_id,
306+
id(aten_program.graph),
307+
)
308+
309+
# Verify the edge dialect program matches the edge manager
310+
self.check_graph_closeness(
311+
etrecord.edge_dialect_program,
312+
edge_manager.exported_program().graph_module,
313+
)
314+
315+
def test_to_edge_transform_and_lower_without_etrecord_generation(self):
316+
"""Test that to_edge_transform_and_lower works correctly without ETRecord generation."""
317+
f = models.BasicSinMax()
318+
aten_program = export(f, f.get_random_inputs(), strict=True)
319+
320+
# Test with generate_etrecord=False (default)
321+
edge_manager = to_edge_transform_and_lower(aten_program)
322+
323+
# Verify that no ETRecord was generated
324+
self.assertIsNone(edge_manager._etrecord)
325+
326+
# Verify that the edge manager still works correctly
327+
self.assertIsNotNone(edge_manager.exported_program())
328+
329+
def test_get_etrecord_from_executorch_program_manager(self):
330+
"""Test getting ETRecord from ExecutorchProgramManager using get_etrecord() method."""
331+
f = models.BasicSinMax()
332+
aten_program = export(f, f.get_random_inputs(), strict=True)
333+
334+
# Generate edge manager with ETRecord
335+
edge_manager = to_edge_transform_and_lower(
336+
aten_program,
337+
generate_etrecord=True,
338+
)
339+
340+
# Convert to executorch
341+
et_manager = edge_manager.to_executorch()
342+
343+
# Test get_etrecord method
344+
etrecord = et_manager.get_etrecord()
345+
self.assertIsNotNone(etrecord)
346+
347+
# Verify that the returned ETRecord has all expected data
348+
self.assertIsNotNone(etrecord.exported_program)
349+
self.assertIsNotNone(etrecord.export_graph_id)
350+
self.assertIsNotNone(etrecord.edge_dialect_program)
351+
self.assertIsNotNone(etrecord._debug_handle_map)
352+
self.assertIsNotNone(etrecord._delegate_map)
353+
354+
# Verify the data matches the original input
355+
self.check_graph_closeness(
356+
etrecord.exported_program,
357+
aten_program.graph_module,
358+
)
359+
self.assertEqual(
360+
etrecord.export_graph_id,
361+
id(aten_program.graph),
362+
)
363+
364+
# Verify the executorch program data matches
365+
# ETRecord stores data directly (not JSON serialized), so compare with original data
366+
self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map)
367+
self.assertEqual(etrecord._delegate_map, et_manager.delegate_map)
368+
369+
def test_get_etrecord_from_executorch_program_manager_without_generation(self):
370+
"""Test getting ETRecord from ExecutorchProgramManager when ETRecord was not generated."""
371+
f = models.BasicSinMax()
372+
aten_program = export(f, f.get_random_inputs(), strict=True)
373+
374+
# Generate edge manager without ETRecord
375+
edge_manager = to_edge_transform_and_lower(aten_program)
376+
377+
# Verify no ETRecord on edge manager
378+
self.assertIsNone(edge_manager._etrecord)
379+
380+
# Convert to executorch
381+
et_manager = edge_manager.to_executorch()
382+
383+
# Verify no ETRecord on executorch manager
384+
self.assertIsNone(et_manager._etrecord)
385+
386+
# Test get_etrecord method should raise RuntimeError
387+
with self.assertRaises(RuntimeError) as context:
388+
et_manager.get_etrecord()
389+
390+
self.assertIn("ETRecord was not generated", str(context.exception))
391+
392+
def test_to_edge_transform_and_lower_etrecord_save_and_parse(self):
393+
"""Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed."""
394+
f = models.BasicSinMax()
395+
aten_program = export(f, f.get_random_inputs(), strict=True)
396+
397+
# Generate edge manager with ETRecord
398+
edge_manager = to_edge_transform_and_lower(
399+
aten_program,
400+
generate_etrecord=True,
401+
)
402+
403+
# Convert to executorch to get complete ETRecord
404+
et_manager = edge_manager.to_executorch()
405+
etrecord = et_manager.get_etrecord()
406+
407+
with tempfile.TemporaryDirectory() as tmpdirname:
408+
etrecord_path = tmpdirname + "/etrecord_flow2.bin"
409+
410+
etrecord.save(etrecord_path)
411+
412+
# Parse ETRecord back and verify
413+
parsed_etrecord = parse_etrecord(etrecord_path)
414+
415+
# Validate that all components are preserved
416+
# Note: Skip graph structure comparison due to transformation differences
417+
self.check_graph_closeness(
418+
etrecord.exported_program, parsed_etrecord.exported_program
419+
)
420+
self.check_graph_closeness(
421+
etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program
422+
)
423+
424+
# Validate executorch program data
425+
self.assertEqual(
426+
parsed_etrecord._debug_handle_map,
427+
json.loads(json.dumps(et_manager.debug_handle_map)),
428+
)
429+
self.assertEqual(
430+
parsed_etrecord._delegate_map,
431+
json.loads(json.dumps(et_manager.delegate_map)),
432+
)
433+
434+
# Validate export graph id
435+
self.assertEqual(
436+
parsed_etrecord.export_graph_id,
437+
id(aten_program.graph),
438+
)
439+
278440
def test_add_extra_export_modules(self):
279441
"""Test add_extra_export_modules when ETRecord already has a graph_map."""
280442
captured_output, edge_output, et_output = self.get_test_model()

exir/program/_program.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,15 @@ def _copy_module(new_prog, new_gm):
291291
setattr(new_prog, node.target, t)
292292

293293

294+
def _create_empty_etrecord():
295+
# Import etrecord at runtime to resolve cyclic dependencies (program -> etrecord -> program).
296+
# This also ensures that etrecord-related packages do not affect the export flow.
297+
# @manual
298+
from executorch.devtools.etrecord import ETRecord
299+
300+
return ETRecord()
301+
302+
294303
def lift_constant_tensor_pass(ep):
295304
"""
296305
Takes an ExportedProgram and returns the ExportedProgram modified in-place,
@@ -1103,6 +1112,7 @@ def _gen_edge_manager_for_partitioners(
11031112
aten_programs: Dict[str, ExportedProgram],
11041113
config: EdgeCompileConfig,
11051114
constant_methods: Optional[Dict[str, Any]],
1115+
generate_etrecord: Optional[bool] = False,
11061116
) -> "EdgeProgramManager":
11071117
"""
11081118
Generates EdgeProgramManager for subsequent lowering to the
@@ -1179,6 +1189,13 @@ def _gen_edge_manager_for_partitioners(
11791189
config,
11801190
list(set().union(*ops_set_to_not_decompose_by_program.values())),
11811191
)
1192+
1193+
if generate_etrecord:
1194+
etrecord = _create_empty_etrecord()
1195+
etrecord.add_exported_program(aten_programs)
1196+
etrecord.add_edge_dialect_program(copy.deepcopy(edge_manager))
1197+
edge_manager._etrecord = etrecord
1198+
11821199
return edge_manager
11831200

11841201

@@ -1220,6 +1237,7 @@ def to_edge_transform_and_lower( # noqa: C901
12201237
] = None,
12211238
constant_methods: Optional[Dict[str, Any]] = None,
12221239
compile_config: Optional[EdgeCompileConfig] = None,
1240+
generate_etrecord: bool = False,
12231241
) -> "EdgeProgramManager":
12241242
"""
12251243
:func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of
@@ -1260,6 +1278,8 @@ def to_edge_transform_and_lower( # noqa: C901
12601278
compile_config: An optional argument used to provide greater control over the
12611279
transformation to edge dialect process.
12621280
1281+
generate_etrecord: An optional argument used to generate an etrecord for debugging purposes.
1282+
12631283
Returns:
12641284
EdgeProgramManager
12651285
"""
@@ -1279,7 +1299,7 @@ def to_edge_transform_and_lower( # noqa: C901
12791299
partitioner, aten_programs
12801300
)
12811301
edge_manager = _gen_edge_manager_for_partitioners(
1282-
partitioner, aten_programs, config, constant_methods
1302+
partitioner, aten_programs, config, constant_methods, generate_etrecord
12831303
)
12841304

12851305
if transform_passes is not None:
@@ -1447,6 +1467,8 @@ def __init__(
14471467
program, self._named_data_store
14481468
)
14491469

1470+
self._etrecord = None
1471+
14501472
@property
14511473
def methods(self) -> Set[str]:
14521474
"""
@@ -1643,13 +1665,19 @@ def to_executorch(
16431665
_copy_module(program.graph_module, new_gm)
16441666
execution_programs[name] = program
16451667

1646-
return ExecutorchProgramManager(
1668+
et_pm = ExecutorchProgramManager(
16471669
execution_programs,
16481670
self._config_methods,
16491671
config,
16501672
self._named_data_store.get_named_data_store_output(),
16511673
)
16521674

1675+
if self._etrecord is not None:
1676+
self._etrecord.add_executorch_program(et_pm)
1677+
et_pm._etrecord = self._etrecord
1678+
1679+
return et_pm
1680+
16531681

16541682
class ExecutorchProgramManager:
16551683
"""
@@ -1713,6 +1741,7 @@ def __init__(
17131741
self._named_data,
17141742
)
17151743
self._buffer: Optional[bytes] = None
1744+
self._etrecord = None
17161745

17171746
@property
17181747
def methods(self) -> Set[str]:
@@ -1785,6 +1814,21 @@ def buffer(self) -> bytes:
17851814
self._buffer = bytes(self._pte_data)
17861815
return self._buffer
17871816

1817+
def get_etrecord(self):
1818+
"""
1819+
Get the generated ETRecord if etrecord generation was enabled.
1820+
1821+
Returns:
1822+
ETRecord object if generation was enabled, None otherwise
1823+
1824+
Raises:
1825+
RuntimeError: if ETRecord object was not generated.
1826+
"""
1827+
1828+
if self._etrecord is None:
1829+
raise RuntimeError("ETRecord was not generated")
1830+
return self._etrecord
1831+
17881832
def write_to_file(self, open_file: io.BufferedIOBase) -> None:
17891833
"""
17901834
Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over

0 commit comments

Comments
 (0)