Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions devtools/etrecord/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ python_unittest(
"//executorch/devtools/etrecord:etrecord",
"//executorch/exir:lib",
"//executorch/exir/tests:models",
"//executorch/export:lib",
],
)

Expand Down
136 changes: 136 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.program._program import to_edge, to_edge_transform_and_lower

from executorch.export import export as etexport, ExportRecipe, StageType
from torch.export import export


Expand Down Expand Up @@ -136,6 +138,33 @@ def get_test_model_with_bundled_program(self):
bundled_program = BundledProgram(et_output, method_test_suites)
return (aten_dialect, edge_program_copy, bundled_program)

def get_test_export_session(self, generate_etrecord=False, to_edge_flow=False):
f = models.BasicSinMax()
example_inputs = [f.get_random_inputs()]
export_recipe = None

if to_edge_flow:
export_recipe = ExportRecipe(
pipeline_stages=[
StageType.TORCH_EXPORT,
StageType.TO_EDGE,
StageType.TO_BACKEND,
StageType.TO_EXECUTORCH,
]
)
else:
export_recipe = ExportRecipe()

# Test with generate_etrecord=True
export_session = etexport(
model=f,
example_inputs=example_inputs,
export_recipe=export_recipe,
generate_etrecord=generate_etrecord,
)

return export_session

# Serialized and deserialized graph modules are not completely the same, so we check
# that they are close enough and match especially on the parameters we care about in the Developer Tools.
def check_graph_closeness(self, graph_a, graph_b):
Expand Down Expand Up @@ -1261,6 +1290,113 @@ def test_add_all_programs_sequentially(self):
json.loads(json.dumps(et_output.delegate_map)),
)

def test_executorch_export_with_etrecord_generation(self):
"""Test that executorch.export generates ETRecord correctly when generate_etrecord=True."""
# Verify that ETRecord was generated and can be retrieved
export_session = self.get_test_export_session(generate_etrecord=True)
etrecord = export_session.get_etrecord()
self.assertIsNotNone(etrecord)
self.assert_etrecord_saveable(etrecord)

# Verify the executorch program data matches
et_manager = export_session.get_executorch_program_manager()
self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map)
self.assertEqual(etrecord._delegate_map, et_manager.delegate_map)

def test_executorch_export_without_etrecord_generation(self):
"""Test that executorch.export works correctly without ETRecord generation."""
# Test with generate_etrecord=False (default)
export_session = self.get_test_export_session(generate_etrecord=False)

# Verify that no ETRecord was generated
with self.assertRaises(RuntimeError) as context:
export_session.get_etrecord()

self.assertIn("ETRecord was not generated", str(context.exception))

# Verify that the export session still works correctly
self.assertIsNotNone(export_session.get_executorch_program_manager())
self.assertTrue(len(export_session.get_pte_buffer()) > 0)

def test_executorch_export_etrecord_save_and_parse(self):
"""Test that ETRecord generated by executorch.export can be saved and parsed."""
export_session = self.get_test_export_session(generate_etrecord=True)

etrecord = export_session.get_etrecord()

with tempfile.TemporaryDirectory() as tmpdirname:
etrecord_path = tmpdirname + "/etrecord_export.bin"

etrecord.save(etrecord_path)

# Parse ETRecord back and verify
parsed_etrecord = parse_etrecord(etrecord_path)

# Validate that all components are preserved
self.assertIsNotNone(parsed_etrecord.exported_program)
self.assertIsNotNone(parsed_etrecord.edge_dialect_program)

# Validate executorch program data
et_manager = export_session.get_executorch_program_manager()
self.assertEqual(
parsed_etrecord._debug_handle_map,
json.loads(json.dumps(et_manager.debug_handle_map)),
)
self.assertEqual(
parsed_etrecord._delegate_map,
json.loads(json.dumps(et_manager.delegate_map)),
)

# Validate export graph id is preserved
self.assertIsNotNone(parsed_etrecord.export_graph_id)

def test_executorch_export_with_to_edge_flow(self):
"""Test executorch.export with TO_EDGE flow and ETRecord generation."""
export_session = self.get_test_export_session(
generate_etrecord=True,
to_edge_flow=True,
)

# Verify that ETRecord was generated
etrecord = export_session.get_etrecord()
self.assertIsNotNone(etrecord)
self.assert_etrecord_saveable(etrecord)

def test_executorch_export_etrecord_with_to_edge_flow_save_and_parse(self):
"""Test that ETRecord generated by executorch.export can be saved and parsed."""
export_session = self.get_test_export_session(
generate_etrecord=True,
to_edge_flow=True,
)

etrecord = export_session.get_etrecord()

with tempfile.TemporaryDirectory() as tmpdirname:
etrecord_path = tmpdirname + "/etrecord_export.bin"

etrecord.save(etrecord_path)

# Parse ETRecord back and verify
parsed_etrecord = parse_etrecord(etrecord_path)

# Validate that all components are preserved
self.assertIsNotNone(parsed_etrecord.exported_program)
self.assertIsNotNone(parsed_etrecord.edge_dialect_program)

# Validate executorch program data
et_manager = export_session.get_executorch_program_manager()
self.assertEqual(
parsed_etrecord._debug_handle_map,
json.loads(json.dumps(et_manager.debug_handle_map)),
)
self.assertEqual(
parsed_etrecord._delegate_map,
json.loads(json.dumps(et_manager.delegate_map)),
)

# Validate export graph id is preserved
self.assertIsNotNone(parsed_etrecord.export_graph_id)

def test_update_representative_inputs_with_list(self):
"""Test update_representative_inputs with a list of ProgramInput objects."""
captured_output, edge_output, et_output = self.get_test_model()
Expand Down
19 changes: 19 additions & 0 deletions export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def export(
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
artifact_dir: Optional[str] = None,
generate_etrecord: Optional[bool] = False,
) -> "ExportSession":
"""
Create and configure an ExportSession with the given parameters.
Expand All @@ -61,6 +62,7 @@ def export(
dynamic_shapes: Optional dynamic shape specifications
constant_methods: Optional dictionary of constant methods
artifact_dir: Optional directory to store artifacts
generate_etrecord: Optional flag to generate an etrecord

Returns:
A configured ExportSession instance with the export process completed if requested
Expand All @@ -73,6 +75,7 @@ def export(
dynamic_shapes=dynamic_shapes,
constant_methods=constant_methods,
artifact_dir=artifact_dir,
generate_etrecord=generate_etrecord,
)
session.export()

Expand Down Expand Up @@ -104,6 +107,7 @@ def __init__(
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
artifact_dir: Optional[str] = None,
generate_etrecord: Optional[bool] = False,
) -> None:
"""
Initialize the ExportSession with model, inputs, and recipe.
Expand All @@ -118,6 +122,7 @@ def __init__(
dynamic_shapes: Optional dynamic shape specifications
constant_methods: Optional dictionary of constant methods
artifact_dir: Optional directory to store artifacts
generate_etrecord: Optional flag to generate an etrecord
"""
# Standardize model to dictionary format
self._model = model if isinstance(model, dict) else {"forward": model}
Expand Down Expand Up @@ -165,6 +170,7 @@ def __init__(
"export_recipe": self._export_recipe,
"session_name": name,
"artifact_dir": artifact_dir,
"generate_etrecord": generate_etrecord,
}

self._stage_to_artifacts: Dict[StageType, PipelineArtifact] = {}
Expand Down Expand Up @@ -453,3 +459,16 @@ def print_delegation_info(self) -> None:
logging.info(tabulate(df, headers="keys", tablefmt="fancy_grid"))
else:
logging.info("No delegation info available")

# Use Any instead of ETRecord as return type to avoid static dependency on etrecord
def get_etrecord(self) -> Any:
"""
Get the etrecord from the ExecuTorchProgramManager.

Returns:
The etrecord in the ExecuTorchProgramManager

Raises:
RuntimeError: If the ExecuTorchManager is unavailable, or etrecord is not available in the ExecuTorchProgramManager
"""
return self.get_executorch_program_manager().get_etrecord()
3 changes: 3 additions & 0 deletions export/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def run(self, artifact: PipelineArtifact) -> None:
"""
exported_programs = artifact.data
constant_methods = artifact.get_context("constant_methods")
generate_etrecord = artifact.get_context("generate_etrecord", False)

with validation_disabled():
edge_program_manager = to_edge_transform_and_lower(
Expand All @@ -207,6 +208,7 @@ def run(self, artifact: PipelineArtifact) -> None:
transform_passes=self._transform_passes,
constant_methods=constant_methods,
compile_config=self._compile_config,
generate_etrecord=generate_etrecord,
)

delegation_info = get_delegation_info(
Expand Down Expand Up @@ -418,6 +420,7 @@ def run(self, artifact: PipelineArtifact) -> None:
exported_programs,
constant_methods=constant_methods,
compile_config=self._edge_compile_config,
generate_etrecord=artifact.get_context("generate_etrecord", False),
)

self._artifact = artifact.copy_with_new_data(edge_program_manager)
Expand Down
Loading