Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 82 additions & 4 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
ETRecord,
ETRecordReservedFileNames,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from executorch.exir.program._program import to_edge_transform_and_lower
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.program._program import to_edge, to_edge_transform_and_lower
from torch.export import export


Expand Down Expand Up @@ -105,11 +105,13 @@ def assert_etrecord_saveable(self, etrecord: ETRecord) -> None:
self.assertIsNotNone(etrecord._debug_handle_map)
self.assertIsNotNone(etrecord._delegate_map)

def get_test_model(self):
def get_test_model(self, generate_etrecord=False):
f = models.BasicSinMax()
aten_dialect = export(f, f.get_random_inputs(), strict=True)
edge_program: EdgeProgramManager = to_edge(
aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False)
aten_dialect,
compile_config=EdgeCompileConfig(_check_ir_validity=False),
generate_etrecord=generate_etrecord,
)
edge_program_copy = copy.deepcopy(edge_program)
return (aten_dialect, edge_program_copy, edge_program.to_executorch())
Expand Down Expand Up @@ -392,6 +394,82 @@ def test_get_etrecord_from_executorch_program_manager_without_generation(self):

self.assertIn("ETRecord was not generated", str(context.exception))

def test_to_edge_with_etrecord_generation(self):
"""Test that to_edge generates ETRecord correctly."""
aten_program, edge_manager, _ = self.get_test_model(generate_etrecord=True)

# Verify that ETRecord was generated and attached
self.assertIsNotNone(edge_manager._etrecord)
etrecord = edge_manager._etrecord
self.assert_legal_etrecord_in_edge_program(etrecord)

# Verify the exported program matches the input
self.check_graph_closeness(
etrecord.exported_program,
aten_program.graph_module,
)
self.assertEqual(
etrecord.export_graph_id,
id(aten_program.graph),
)

# Verify the edge dialect program matches the edge manager
self.check_graph_closeness(
etrecord.edge_dialect_program,
edge_manager.exported_program().graph_module,
)

def test_to_edge_without_etrecord_generation(self):
"""Test that to_edge works correctly without ETRecord generation."""
# Test with generate_etrecord=False (default)
_, edge_manager, et_manager = self.get_test_model()

# Verify that no ETRecord was generated
self.assertIsNone(edge_manager._etrecord)

# Test get_etrecord method should raise RuntimeError
with self.assertRaises(RuntimeError):
et_manager.get_etrecord()

def test_to_edge_etrecord_save_and_parse(self):
"""Test that ETRecord generated by to_edge can be saved and parsed."""
aten_program, _, et_manager = self.get_test_model(generate_etrecord=True)

etrecord = et_manager.get_etrecord()

with tempfile.TemporaryDirectory() as tmpdirname:
etrecord_path = tmpdirname + "/etrecord_to_edge.bin"

etrecord.save(etrecord_path)

# Parse ETRecord back and verify
parsed_etrecord = parse_etrecord(etrecord_path)

# Validate that all components are preserved
# Note: Skip graph structure comparison due to transformation differences
self.check_graph_closeness(
etrecord.exported_program, parsed_etrecord.exported_program
)
self.check_graph_closeness(
etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program
)

# Validate executorch program data
self.assertEqual(
parsed_etrecord._debug_handle_map,
json.loads(json.dumps(et_manager.debug_handle_map)),
)
self.assertEqual(
parsed_etrecord._delegate_map,
json.loads(json.dumps(et_manager.delegate_map)),
)

# Validate export graph id
self.assertEqual(
parsed_etrecord.export_graph_id,
id(aten_program.graph),
)

def test_to_edge_transform_and_lower_etrecord_save_and_parse(self):
"""Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed."""
f = models.BasicSinMax()
Expand Down
12 changes: 11 additions & 1 deletion exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,7 @@ def to_edge(
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
constant_methods: Optional[Dict[str, Any]] = None,
compile_config: Optional[EdgeCompileConfig] = None,
generate_etrecord: bool = False,
) -> "EdgeProgramManager":
"""
:func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in
Expand All @@ -1388,6 +1389,8 @@ def to_edge(

compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.

generate_etrecord: An optional argument used to generate an etrecord for debugging purposes. Default is False.

Returns:
EdgeProgramManager
"""
Expand Down Expand Up @@ -1441,7 +1444,14 @@ def to_edge(
logging.info(f"Input program {name} is not in Edge dialect.")
raise e

return EdgeProgramManager(edge_programs, constant_methods, config)
epm = EdgeProgramManager(edge_programs, constant_methods, config)
if generate_etrecord:
etrecord = _create_empty_etrecord()
etrecord.add_exported_program(aten_programs)
etrecord.add_edge_dialect_program(copy.deepcopy(epm))
epm._etrecord = etrecord

return epm


class EdgeProgramManager:
Expand Down
Loading