Skip to content

Commit 6fd97ab

Browse files
authored
make to_edge support etrecord generation
Differential Revision: D79707919 Pull Request resolved: #13244
1 parent bb66af0 commit 6fd97ab

File tree

2 files changed

+93
-5
lines changed

2 files changed

+93
-5
lines changed

devtools/etrecord/tests/etrecord_test.py

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
ETRecord,
2626
ETRecordReservedFileNames,
2727
)
28-
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
29-
from executorch.exir.program._program import to_edge_transform_and_lower
28+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
29+
from executorch.exir.program._program import to_edge, to_edge_transform_and_lower
3030
from torch.export import export
3131

3232

@@ -106,11 +106,13 @@ def assert_etrecord_saveable(self, etrecord: ETRecord) -> None:
106106
self.assertIsNotNone(etrecord._debug_handle_map)
107107
self.assertIsNotNone(etrecord._delegate_map)
108108

109-
def get_test_model(self):
109+
def get_test_model(self, generate_etrecord=False):
110110
f = models.BasicSinMax()
111111
aten_dialect = export(f, f.get_random_inputs(), strict=True)
112112
edge_program: EdgeProgramManager = to_edge(
113-
aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False)
113+
aten_dialect,
114+
compile_config=EdgeCompileConfig(_check_ir_validity=False),
115+
generate_etrecord=generate_etrecord,
114116
)
115117
edge_program_copy = copy.deepcopy(edge_program)
116118
return (aten_dialect, edge_program_copy, edge_program.to_executorch())
@@ -428,6 +430,82 @@ def test_get_etrecord_from_executorch_program_manager_without_generation(self):
428430

429431
self.assertIn("ETRecord was not generated", str(context.exception))
430432

433+
def test_to_edge_with_etrecord_generation(self):
434+
"""Test that to_edge generates ETRecord correctly."""
435+
aten_program, edge_manager, _ = self.get_test_model(generate_etrecord=True)
436+
437+
# Verify that ETRecord was generated and attached
438+
self.assertIsNotNone(edge_manager._etrecord)
439+
etrecord = edge_manager._etrecord
440+
self.assert_legal_etrecord_in_edge_program(etrecord)
441+
442+
# Verify the exported program matches the input
443+
self.check_graph_closeness(
444+
etrecord.exported_program,
445+
aten_program.graph_module,
446+
)
447+
self.assertEqual(
448+
etrecord.export_graph_id,
449+
id(aten_program.graph),
450+
)
451+
452+
# Verify the edge dialect program matches the edge manager
453+
self.check_graph_closeness(
454+
etrecord.edge_dialect_program,
455+
edge_manager.exported_program().graph_module,
456+
)
457+
458+
def test_to_edge_without_etrecord_generation(self):
459+
"""Test that to_edge works correctly without ETRecord generation."""
460+
# Test with generate_etrecord=False (default)
461+
_, edge_manager, et_manager = self.get_test_model()
462+
463+
# Verify that no ETRecord was generated
464+
self.assertIsNone(edge_manager._etrecord)
465+
466+
# Test get_etrecord method should raise RuntimeError
467+
with self.assertRaises(RuntimeError):
468+
et_manager.get_etrecord()
469+
470+
def test_to_edge_etrecord_save_and_parse(self):
471+
"""Test that ETRecord generated by to_edge can be saved and parsed."""
472+
aten_program, _, et_manager = self.get_test_model(generate_etrecord=True)
473+
474+
etrecord = et_manager.get_etrecord()
475+
476+
with tempfile.TemporaryDirectory() as tmpdirname:
477+
etrecord_path = tmpdirname + "/etrecord_to_edge.bin"
478+
479+
etrecord.save(etrecord_path)
480+
481+
# Parse ETRecord back and verify
482+
parsed_etrecord = parse_etrecord(etrecord_path)
483+
484+
# Validate that all components are preserved
485+
# Note: Skip graph structure comparison due to transformation differences
486+
self.check_graph_closeness(
487+
etrecord.exported_program, parsed_etrecord.exported_program
488+
)
489+
self.check_graph_closeness(
490+
etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program
491+
)
492+
493+
# Validate executorch program data
494+
self.assertEqual(
495+
parsed_etrecord._debug_handle_map,
496+
json.loads(json.dumps(et_manager.debug_handle_map)),
497+
)
498+
self.assertEqual(
499+
parsed_etrecord._delegate_map,
500+
json.loads(json.dumps(et_manager.delegate_map)),
501+
)
502+
503+
# Validate export graph id
504+
self.assertEqual(
505+
parsed_etrecord.export_graph_id,
506+
id(aten_program.graph),
507+
)
508+
431509
def test_to_edge_transform_and_lower_etrecord_save_and_parse(self):
432510
"""Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed."""
433511
f = models.BasicSinMax()

exir/program/_program.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1376,6 +1376,7 @@ def to_edge(
13761376
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
13771377
constant_methods: Optional[Dict[str, Any]] = None,
13781378
compile_config: Optional[EdgeCompileConfig] = None,
1379+
generate_etrecord: bool = False,
13791380
) -> "EdgeProgramManager":
13801381
"""
13811382
:func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in
@@ -1388,6 +1389,8 @@ def to_edge(
13881389
13891390
compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.
13901391
1392+
generate_etrecord: An optional argument used to generate an etrecord for debugging purposes. Default is False.
1393+
13911394
Returns:
13921395
EdgeProgramManager
13931396
"""
@@ -1441,7 +1444,14 @@ def to_edge(
14411444
logging.info(f"Input program {name} is not in Edge dialect.")
14421445
raise e
14431446

1444-
return EdgeProgramManager(edge_programs, constant_methods, config)
1447+
epm = EdgeProgramManager(edge_programs, constant_methods, config)
1448+
if generate_etrecord:
1449+
etrecord = _create_empty_etrecord()
1450+
etrecord.add_exported_program(aten_programs)
1451+
etrecord.add_edge_dialect_program(copy.deepcopy(epm))
1452+
epm._etrecord = etrecord
1453+
1454+
return epm
14451455

14461456

14471457
class EdgeProgramManager:

0 commit comments

Comments
 (0)