Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions devtools/etrecord/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@ python_unittest(
name = "etrecord_test",
srcs = ["etrecord_test.py"],
deps = [
"//caffe2:torch",
"//executorch/devtools/bundled_program:config",
"//executorch/devtools/bundled_program:core",
"//executorch/devtools/etrecord:etrecord",
"//executorch/exir:lib",
"//executorch/exir/tests:models",
":etrecord_test_library"
],
)

Expand All @@ -26,5 +21,6 @@ python_library(
"//executorch/devtools/etrecord:etrecord",
"//executorch/exir:lib",
"//executorch/exir/tests:models",
"//executorch/export:lib",
],
)
136 changes: 136 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.program._program import to_edge, to_edge_transform_and_lower

from executorch.export import export as etexport, ExportRecipe, StageType
from torch.export import export


Expand Down Expand Up @@ -136,6 +138,33 @@ def get_test_model_with_bundled_program(self):
bundled_program = BundledProgram(et_output, method_test_suites)
return (aten_dialect, edge_program_copy, bundled_program)

def get_test_export_session(self, generate_etrecord=False, to_edge_flow=False):
f = models.BasicSinMax()
example_inputs = [f.get_random_inputs()]
export_recipe = None

if to_edge_flow:
export_recipe = ExportRecipe(
pipeline_stages=[
StageType.TORCH_EXPORT,
StageType.TO_EDGE,
StageType.TO_BACKEND,
StageType.TO_EXECUTORCH,
]
)
else:
export_recipe = ExportRecipe()

# Test with generate_etrecord=True
export_session = etexport(
model=f,
example_inputs=example_inputs,
export_recipe=export_recipe,
generate_etrecord=generate_etrecord,
)

return export_session

# Serialized and deserialized graph modules are not completely the same, so we check
# that they are close enough and match especially on the parameters we care about in the Developer Tools.
def check_graph_closeness(self, graph_a, graph_b):
Expand Down Expand Up @@ -1261,6 +1290,113 @@ def test_add_all_programs_sequentially(self):
json.loads(json.dumps(et_output.delegate_map)),
)

def test_executorch_export_with_etrecord_generation(self):
"""Test that executorch.export generates ETRecord correctly when generate_etrecord=True."""
# Verify that ETRecord was generated and can be retrieved
export_session = self.get_test_export_session(generate_etrecord=True)
etrecord = export_session.get_etrecord()
self.assertIsNotNone(etrecord)
self.assert_etrecord_saveable(etrecord)

# Verify the executorch program data matches
et_manager = export_session.get_executorch_program_manager()
self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map)
self.assertEqual(etrecord._delegate_map, et_manager.delegate_map)

def test_executorch_export_without_etrecord_generation(self):
"""Test that executorch.export works correctly without ETRecord generation."""
# Test with generate_etrecord=False (default)
export_session = self.get_test_export_session(generate_etrecord=False)

# Verify that no ETRecord was generated
with self.assertRaises(RuntimeError) as context:
export_session.get_etrecord()

self.assertIn("ETRecord was not generated", str(context.exception))

# Verify that the export session still works correctly
self.assertIsNotNone(export_session.get_executorch_program_manager())
self.assertTrue(len(export_session.get_pte_buffer()) > 0)

def test_executorch_export_etrecord_save_and_parse(self):
"""Test that ETRecord generated by executorch.export can be saved and parsed."""
export_session = self.get_test_export_session(generate_etrecord=True)

etrecord = export_session.get_etrecord()

with tempfile.TemporaryDirectory() as tmpdirname:
etrecord_path = tmpdirname + "/etrecord_export.bin"

etrecord.save(etrecord_path)

# Parse ETRecord back and verify
parsed_etrecord = parse_etrecord(etrecord_path)

# Validate that all components are preserved
self.assertIsNotNone(parsed_etrecord.exported_program)
self.assertIsNotNone(parsed_etrecord.edge_dialect_program)

# Validate executorch program data
et_manager = export_session.get_executorch_program_manager()
self.assertEqual(
parsed_etrecord._debug_handle_map,
json.loads(json.dumps(et_manager.debug_handle_map)),
)
self.assertEqual(
parsed_etrecord._delegate_map,
json.loads(json.dumps(et_manager.delegate_map)),
)

# Validate export graph id is preserved
self.assertIsNotNone(parsed_etrecord.export_graph_id)

def test_executorch_export_with_to_edge_flow(self):
"""Test executorch.export with TO_EDGE flow and ETRecord generation."""
export_session = self.get_test_export_session(
generate_etrecord=True,
to_edge_flow=True,
)

# Verify that ETRecord was generated
etrecord = export_session.get_etrecord()
self.assertIsNotNone(etrecord)
self.assert_etrecord_saveable(etrecord)

def test_executorch_export_etrecord_with_to_edge_flow_save_and_parse(self):
"""Test that ETRecord generated by executorch.export can be saved and parsed."""
export_session = self.get_test_export_session(
generate_etrecord=True,
to_edge_flow=True,
)

etrecord = export_session.get_etrecord()

with tempfile.TemporaryDirectory() as tmpdirname:
etrecord_path = tmpdirname + "/etrecord_export.bin"

etrecord.save(etrecord_path)

# Parse ETRecord back and verify
parsed_etrecord = parse_etrecord(etrecord_path)

# Validate that all components are preserved
self.assertIsNotNone(parsed_etrecord.exported_program)
self.assertIsNotNone(parsed_etrecord.edge_dialect_program)

# Validate executorch program data
et_manager = export_session.get_executorch_program_manager()
self.assertEqual(
parsed_etrecord._debug_handle_map,
json.loads(json.dumps(et_manager.debug_handle_map)),
)
self.assertEqual(
parsed_etrecord._delegate_map,
json.loads(json.dumps(et_manager.delegate_map)),
)

# Validate export graph id is preserved
self.assertIsNotNone(parsed_etrecord.export_graph_id)

def test_update_representative_inputs_with_list(self):
"""Test update_representative_inputs with a list of ProgramInput objects."""
captured_output, edge_output, et_output = self.get_test_model()
Expand Down
19 changes: 19 additions & 0 deletions export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def export(
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
artifact_dir: Optional[str] = None,
generate_etrecord: Optional[bool] = False,
) -> "ExportSession":
"""
Create and configure an ExportSession with the given parameters.
Expand All @@ -61,6 +62,7 @@ def export(
dynamic_shapes: Optional dynamic shape specifications
constant_methods: Optional dictionary of constant methods
artifact_dir: Optional directory to store artifacts
generate_etrecord: Optional flag to generate an etrecord

Returns:
A configured ExportSession instance with the export process completed if requested
Expand All @@ -73,6 +75,7 @@ def export(
dynamic_shapes=dynamic_shapes,
constant_methods=constant_methods,
artifact_dir=artifact_dir,
generate_etrecord=generate_etrecord,
)
session.export()

Expand Down Expand Up @@ -104,6 +107,7 @@ def __init__(
dynamic_shapes: Optional[Union[Any, Dict[str, Any]]] = None,
constant_methods: Optional[Union[Dict[str, Callable]]] = None,
artifact_dir: Optional[str] = None,
generate_etrecord: Optional[bool] = False,
) -> None:
"""
Initialize the ExportSession with model, inputs, and recipe.
Expand All @@ -118,6 +122,7 @@ def __init__(
dynamic_shapes: Optional dynamic shape specifications
constant_methods: Optional dictionary of constant methods
artifact_dir: Optional directory to store artifacts
generate_etrecord: Optional flag to generate an etrecord
"""
# Standardize model to dictionary format
self._model = model if isinstance(model, dict) else {"forward": model}
Expand Down Expand Up @@ -165,6 +170,7 @@ def __init__(
"export_recipe": self._export_recipe,
"session_name": name,
"artifact_dir": artifact_dir,
"generate_etrecord": generate_etrecord,
}

self._stage_to_artifacts: Dict[StageType, PipelineArtifact] = {}
Expand Down Expand Up @@ -453,3 +459,16 @@ def print_delegation_info(self) -> None:
logging.info(tabulate(df, headers="keys", tablefmt="fancy_grid"))
else:
logging.info("No delegation info available")

# Use Any instead of ETRecord as return type to avoid static dependency on etrecord
def get_etrecord(self) -> Any:
"""
Get the etrecord from the ExecuTorchProgramManager.

Returns:
The etrecord in the ExecuTorchProgramManager

Raises:
RuntimeError: If the ExecuTorchManager is unavailable, or etrecord is not available in the ExecuTorchProgramManager
"""
return self.get_executorch_program_manager().get_etrecord()
3 changes: 3 additions & 0 deletions export/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def run(self, artifact: PipelineArtifact) -> None:
"""
exported_programs = artifact.data
constant_methods = artifact.get_context("constant_methods")
generate_etrecord = artifact.get_context("generate_etrecord", False)

with validation_disabled():
edge_program_manager = to_edge_transform_and_lower(
Expand All @@ -207,6 +208,7 @@ def run(self, artifact: PipelineArtifact) -> None:
transform_passes=self._transform_passes,
constant_methods=constant_methods,
compile_config=self._compile_config,
generate_etrecord=generate_etrecord,
)

delegation_info = get_delegation_info(
Expand Down Expand Up @@ -418,6 +420,7 @@ def run(self, artifact: PipelineArtifact) -> None:
exported_programs,
constant_methods=constant_methods,
compile_config=self._edge_compile_config,
generate_etrecord=artifact.get_context("generate_etrecord", False),
)

self._artifact = artifact.copy_with_new_data(edge_program_manager)
Expand Down
1 change: 1 addition & 0 deletions export/tests/test_export_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_context_propagation_through_pipeline(self) -> None:
"export_recipe",
"session_name",
"artifact_dir",
"generate_etrecord",
}
self.assertEqual(set(session._run_context.keys()), expected_context_keys)
self.assertEqual(session._run_context["session_name"], "test_session")
Expand Down
1 change: 1 addition & 0 deletions export/tests/test_export_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def test_run_success(self, mock_to_edge: Mock) -> None:
self.exported_programs,
constant_methods=None,
compile_config=mock_config,
generate_etrecord=False,
)

# Verify artifacts are set correctly
Expand Down
Loading