Skip to content

Commit eed1412

Browse files
committed
make to_edge support etrecord generation
This support to_edge export flow etrecord generation supportive. Details can be found in #12925 Differential Revision: [D79707919](https://our.internmc.facebook.com/intern/diff/D79707919/) [ghstack-poisoned]
1 parent 250aeeb commit eed1412

File tree

2 files changed

+93
-5
lines changed

2 files changed

+93
-5
lines changed

devtools/etrecord/tests/etrecord_test.py

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
ETRecord,
2525
ETRecordReservedFileNames,
2626
)
27-
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
28-
from executorch.exir.program._program import to_edge_transform_and_lower
27+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
28+
from executorch.exir.program._program import to_edge, to_edge_transform_and_lower
2929
from torch.export import export
3030

3131

@@ -105,11 +105,13 @@ def assert_etrecord_saveable(self, etrecord: ETRecord) -> None:
105105
self.assertIsNotNone(etrecord._debug_handle_map)
106106
self.assertIsNotNone(etrecord._delegate_map)
107107

108-
def get_test_model(self):
108+
def get_test_model(self, generate_etrecord=False):
109109
f = models.BasicSinMax()
110110
aten_dialect = export(f, f.get_random_inputs(), strict=True)
111111
edge_program: EdgeProgramManager = to_edge(
112-
aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False)
112+
aten_dialect,
113+
compile_config=EdgeCompileConfig(_check_ir_validity=False),
114+
generate_etrecord=generate_etrecord,
113115
)
114116
edge_program_copy = copy.deepcopy(edge_program)
115117
return (aten_dialect, edge_program_copy, edge_program.to_executorch())
@@ -392,6 +394,82 @@ def test_get_etrecord_from_executorch_program_manager_without_generation(self):
392394

393395
self.assertIn("ETRecord was not generated", str(context.exception))
394396

397+
def test_to_edge_with_etrecord_generation(self):
398+
"""Test that to_edge generates ETRecord correctly."""
399+
aten_program, edge_manager, _ = self.get_test_model(generate_etrecord=True)
400+
401+
# Verify that ETRecord was generated and attached
402+
self.assertIsNotNone(edge_manager._etrecord)
403+
etrecord = edge_manager._etrecord
404+
self.assert_legal_etrecord_in_edge_program(etrecord)
405+
406+
# Verify the exported program matches the input
407+
self.check_graph_closeness(
408+
etrecord.exported_program,
409+
aten_program.graph_module,
410+
)
411+
self.assertEqual(
412+
etrecord.export_graph_id,
413+
id(aten_program.graph),
414+
)
415+
416+
# Verify the edge dialect program matches the edge manager
417+
self.check_graph_closeness(
418+
etrecord.edge_dialect_program,
419+
edge_manager.exported_program().graph_module,
420+
)
421+
422+
def test_to_edge_without_etrecord_generation(self):
423+
"""Test that to_edge works correctly without ETRecord generation."""
424+
# Test with generate_etrecord=False (default)
425+
_, edge_manager, et_manager = self.get_test_model()
426+
427+
# Verify that no ETRecord was generated
428+
self.assertIsNone(edge_manager._etrecord)
429+
430+
# Test get_etrecord method should raise RuntimeError
431+
with self.assertRaises(RuntimeError):
432+
et_manager.get_etrecord()
433+
434+
def test_to_edge_etrecord_save_and_parse(self):
435+
"""Test that ETRecord generated by to_edge can be saved and parsed."""
436+
aten_program, _, et_manager = self.get_test_model(generate_etrecord=True)
437+
438+
etrecord = et_manager.get_etrecord()
439+
440+
with tempfile.TemporaryDirectory() as tmpdirname:
441+
etrecord_path = tmpdirname + "/etrecord_to_edge.bin"
442+
443+
etrecord.save(etrecord_path)
444+
445+
# Parse ETRecord back and verify
446+
parsed_etrecord = parse_etrecord(etrecord_path)
447+
448+
# Validate that all components are preserved
449+
# Note: Skip graph structure comparison due to transformation differences
450+
self.check_graph_closeness(
451+
etrecord.exported_program, parsed_etrecord.exported_program
452+
)
453+
self.check_graph_closeness(
454+
etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program
455+
)
456+
457+
# Validate executorch program data
458+
self.assertEqual(
459+
parsed_etrecord._debug_handle_map,
460+
json.loads(json.dumps(et_manager.debug_handle_map)),
461+
)
462+
self.assertEqual(
463+
parsed_etrecord._delegate_map,
464+
json.loads(json.dumps(et_manager.delegate_map)),
465+
)
466+
467+
# Validate export graph id
468+
self.assertEqual(
469+
parsed_etrecord.export_graph_id,
470+
id(aten_program.graph),
471+
)
472+
395473
def test_to_edge_transform_and_lower_etrecord_save_and_parse(self):
396474
"""Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed."""
397475
f = models.BasicSinMax()

exir/program/_program.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,7 @@ def to_edge(
13511351
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
13521352
constant_methods: Optional[Dict[str, Any]] = None,
13531353
compile_config: Optional[EdgeCompileConfig] = None,
1354+
generate_etrecord: Optional[bool] = False,
13541355
) -> "EdgeProgramManager":
13551356
"""
13561357
:func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in
@@ -1363,6 +1364,8 @@ def to_edge(
13631364
13641365
compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.
13651366
1367+
generate_etrecord: An optional argument used to generate an etrecord for debugging purposes.
1368+
13661369
Returns:
13671370
EdgeProgramManager
13681371
"""
@@ -1416,7 +1419,14 @@ def to_edge(
14161419
logging.info(f"Input program {name} is not in Edge dialect.")
14171420
raise e
14181421

1419-
return EdgeProgramManager(edge_programs, constant_methods, config)
1422+
epm = EdgeProgramManager(edge_programs, constant_methods, config)
1423+
if generate_etrecord:
1424+
etrecord = _create_empty_etrecord()
1425+
etrecord.add_exported_program(aten_programs)
1426+
etrecord.add_edge_dialect_program(copy.deepcopy(epm))
1427+
epm._etrecord = etrecord
1428+
1429+
return epm
14201430

14211431

14221432
class EdgeProgramManager:

0 commit comments

Comments
 (0)