diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index 014148f2a13..e149aeab650 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -9,14 +9,15 @@ import json import os import pickle -from dataclasses import dataclass from typing import BinaryIO, Dict, IO, List, Optional, Union from zipfile import BadZipFile, ZipFile +import torch + from executorch import exir -from executorch.devtools.bundled_program.core import BundledProgram -from executorch.devtools.bundled_program.schema.bundled_program_schema import Value +from executorch.devtools.bundled_program.config import ConfigValue +from executorch.devtools.bundled_program.core import BundledProgram from executorch.exir import ( EdgeProgramManager, ExecutorchProgram, @@ -29,8 +30,8 @@ from executorch.exir.serde.export_serialize import SerializedArtifact from executorch.exir.serde.serialize import deserialize, serialize -ProgramInput = List[Value] -ProgramOutput = List[Value] +ProgramInput = ConfigValue +ProgramOutput = torch.Tensor try: # breaking change introduced in python 3.11 @@ -55,96 +56,149 @@ class ETRecordReservedFileNames(StrEnum): REPRESENTATIVE_INPUTS = "representative_inputs" -@dataclass class ETRecord: - exported_program: Optional[ExportedProgram] = None - export_graph_id: Optional[int] = None - edge_dialect_program: Optional[ExportedProgram] = None - graph_map: Optional[Dict[str, ExportedProgram]] = None - _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None - _delegate_map: Optional[ - Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]] - ] = None - _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None - _representative_inputs: Optional[List[ProgramOutput]] = None - - -def _handle_exported_program( - etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram -) -> None: - assert isinstance(ep, ExportedProgram) - serialized_artifact = serialize(ep) - assert isinstance(serialized_artifact.exported_program, bytes) + def __init__( + self, + exported_program: Optional[ExportedProgram] = None, + export_graph_id: Optional[int] = None, + edge_dialect_program: Optional[ExportedProgram] = None, + graph_map: Optional[Dict[str, ExportedProgram]] = None, + _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None, + _delegate_map: Optional[ + Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]] + ] = None, + _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None, + _representative_inputs: Optional[List[ProgramOutput]] = None, + ): + self.exported_program = exported_program + self.export_graph_id = export_graph_id + self.edge_dialect_program = edge_dialect_program + self.graph_map = graph_map + self._debug_handle_map = _debug_handle_map + self._delegate_map = _delegate_map + self._reference_outputs = _reference_outputs + self._representative_inputs = _representative_inputs + + def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None: + """ + Serialize and save the ETRecord to the specified path. + + Args: + path: Path where the ETRecord file will be saved to. + """ + if isinstance(path, (str, os.PathLike)): + # pyre-ignore[6]: In call `os.fspath`, for 1st positional argument, expected `str` but got `Union[PathLike[typing.Any], str]` + path = os.fspath(path) + + etrecord_zip = ZipFile(path, "w") + + try: + self._write_identifier(etrecord_zip) + self._save_programs(etrecord_zip) + self._save_graph_map(etrecord_zip) + self._save_metadata(etrecord_zip) + finally: + etrecord_zip.close() + + def _write_identifier(self, etrecord_zip: ZipFile) -> None: + """Write the magic file identifier.""" + etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "") + + def _save_programs(self, etrecord_zip: ZipFile) -> None: + """Save exported program and edge dialect program.""" + if self.exported_program is not None: + self._save_exported_program( + etrecord_zip, + ETRecordReservedFileNames.EXPORTED_PROGRAM, + "", + self.exported_program, + ) - method_name = f"/{method_name}" if method_name != "" else "" + if self.edge_dialect_program is not None: + self._save_edge_dialect_program(etrecord_zip, self.edge_dialect_program) + + def _save_graph_map(self, etrecord_zip: ZipFile) -> None: + """Save graph map if present.""" + if self.graph_map is not None: + # pyre-ignore[16]: Undefined attribute [16]: `Optional` has no attribute `items`. + for module_name, export_module in self.graph_map.items(): + if "/" in module_name: + base_name, method_name = module_name.rsplit("/", 1) + self._save_exported_program( + etrecord_zip, base_name, method_name, export_module + ) + else: + self._save_exported_program( + etrecord_zip, module_name, "forward", export_module + ) + + def _save_metadata(self, etrecord_zip: ZipFile) -> None: + """Save debug maps, reference outputs, and other metadata.""" + if self._debug_handle_map is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME, + json.dumps(self._debug_handle_map), + ) - etrecord_zip.writestr( - f"{module_name}{method_name}", serialized_artifact.exported_program - ) - etrecord_zip.writestr( - f"{module_name}{method_name}_state_dict", serialized_artifact.state_dict - ) - etrecord_zip.writestr( - f"{module_name}{method_name}_constants", serialized_artifact.constants - ) - etrecord_zip.writestr( - f"{module_name}{method_name}_example_inputs", - serialized_artifact.example_inputs, - ) + if self._delegate_map is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.DELEGATE_MAP_NAME, + json.dumps(self._delegate_map), + ) + if self._reference_outputs is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.REFERENCE_OUTPUTS, + pickle.dumps(self._reference_outputs), + ) -def _handle_export_module( - etrecord_zip: ZipFile, - export_module: Union[ - ExirExportedProgram, - EdgeProgramManager, - ExportedProgram, - ], - module_name: str, -) -> None: - if isinstance(export_module, ExirExportedProgram): - _handle_exported_program( - etrecord_zip, module_name, "forward", export_module.exported_program - ) - elif isinstance(export_module, ExportedProgram): - _handle_exported_program(etrecord_zip, module_name, "forward", export_module) - elif isinstance( - export_module, - (EdgeProgramManager, exir.program._program.EdgeProgramManager), - ): - for method in export_module.methods: - _handle_exported_program( - etrecord_zip, - module_name, - method, - export_module.exported_program(method), + if self._representative_inputs is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.REPRESENTATIVE_INPUTS, + pickle.dumps(self._representative_inputs), ) - else: - raise RuntimeError(f"Unsupported graph module type. {type(export_module)}") + if self.export_graph_id is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.EXPORT_GRAPH_ID, + json.dumps(self.export_graph_id), + ) -def _handle_edge_dialect_exported_program( - etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram -) -> None: - serialized_artifact = serialize(edge_dialect_exported_program) - assert isinstance(serialized_artifact.exported_program, bytes) + def _save_exported_program( + self, + etrecord_zip: ZipFile, + module_name: str, + method_name: str, + ep: ExportedProgram, + ) -> None: + """Save an exported program to the ETRecord zip file.""" + serialized_artifact = serialize(ep) + assert isinstance(serialized_artifact.exported_program, bytes) + + method_name = f"/{method_name}" if method_name != "" else "" + base_name = f"{module_name}{method_name}" + + etrecord_zip.writestr(base_name, serialized_artifact.exported_program) + etrecord_zip.writestr(f"{base_name}_state_dict", serialized_artifact.state_dict) + etrecord_zip.writestr(f"{base_name}_constants", serialized_artifact.constants) + etrecord_zip.writestr( + f"{base_name}_example_inputs", serialized_artifact.example_inputs + ) - etrecord_zip.writestr( - ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM, - serialized_artifact.exported_program, - ) - etrecord_zip.writestr( - f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_state_dict", - serialized_artifact.state_dict, - ) - etrecord_zip.writestr( - f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_constants", - serialized_artifact.constants, - ) - etrecord_zip.writestr( - f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_example_inputs", - serialized_artifact.example_inputs, - ) + def _save_edge_dialect_program( + self, etrecord_zip: ZipFile, edge_dialect_program: ExportedProgram + ) -> None: + """Save the edge dialect program to the ETRecord zip file.""" + serialized_artifact = serialize(edge_dialect_program) + assert isinstance(serialized_artifact.exported_program, bytes) + + base_name = ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM + etrecord_zip.writestr(base_name, serialized_artifact.exported_program) + etrecord_zip.writestr(f"{base_name}_state_dict", serialized_artifact.state_dict) + etrecord_zip.writestr(f"{base_name}_constants", serialized_artifact.constants) + etrecord_zip.writestr( + f"{base_name}_example_inputs", serialized_artifact.example_inputs + ) def _get_reference_outputs( @@ -231,93 +285,141 @@ def generate_etrecord( Returns: None """ + # Process all inputs and prepare data for ETRecord construction + processed_exported_program, export_graph_id = _process_exported_program( + exported_program + ) + graph_map = _process_extra_recorded_modules(extra_recorded_export_modules) + processed_edge_dialect_program = _process_edge_dialect_program(edge_dialect_program) + debug_handle_map, delegate_map, reference_outputs, representative_inputs = ( + _process_executorch_program(executorch_program) + ) - if isinstance(et_record, (str, os.PathLike)): - et_record = os.fspath(et_record) # pyre-ignore + # Create ETRecord instance and save + etrecord = ETRecord( + exported_program=processed_exported_program, + export_graph_id=export_graph_id, + edge_dialect_program=processed_edge_dialect_program, + graph_map=graph_map if graph_map else None, + _debug_handle_map=debug_handle_map, + _delegate_map=delegate_map, + _reference_outputs=reference_outputs, + _representative_inputs=representative_inputs, + ) + + etrecord.save(et_record) - etrecord_zip = ZipFile(et_record, "w") - # Write the magic file identifier that will be used to verify that this file - # is an etrecord when it's used later in the Developer Tools. - etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "") - # Calculate export_graph_id before modifying exported_program +def _process_exported_program( + exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]] +) -> tuple[Optional[ExportedProgram], int]: + """Process exported program and return the processed program and export graph id.""" + processed_exported_program = None export_graph_id = 0 if exported_program is not None: - # If multiple exported programs are provided, only save forward method if isinstance(exported_program, dict) and "forward" in exported_program: - exported_program = exported_program["forward"] + processed_exported_program = exported_program["forward"] + elif isinstance(exported_program, ExportedProgram): + processed_exported_program = exported_program - if isinstance(exported_program, ExportedProgram): - export_graph_id = id(exported_program.graph) - _handle_exported_program( - etrecord_zip, - ETRecordReservedFileNames.EXPORTED_PROGRAM, - "", - exported_program, - ) + if processed_exported_program is not None: + export_graph_id = id(processed_exported_program.graph) + + return processed_exported_program, export_graph_id + + +def _process_extra_recorded_modules( + extra_recorded_export_modules: Optional[ + Dict[ + str, + Union[ + ExportedProgram, + ExirExportedProgram, + EdgeProgramManager, + ], + ] + ] +) -> Dict[str, ExportedProgram]: + """Process extra recorded export modules and return graph map.""" + graph_map = {} if extra_recorded_export_modules is not None: for module_name, export_module in extra_recorded_export_modules.items(): - contains_reserved_name = any( - reserved_name in module_name - for reserved_name in ETRecordReservedFileNames + _validate_module_name(module_name) + _add_module_to_graph_map(graph_map, module_name, export_module) + + return graph_map + + +def _validate_module_name(module_name: str) -> None: + """Validate that module name is not a reserved name.""" + contains_reserved_name = any( + reserved_name in module_name for reserved_name in ETRecordReservedFileNames + ) + if contains_reserved_name: + raise RuntimeError( + f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace." + ) + + +def _add_module_to_graph_map( + graph_map: Dict[str, ExportedProgram], + module_name: str, + export_module: Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager], +) -> None: + """Add export module to graph map based on its type.""" + if isinstance(export_module, ExirExportedProgram): + graph_map[f"{module_name}/forward"] = export_module.exported_program + elif isinstance(export_module, ExportedProgram): + graph_map[f"{module_name}/forward"] = export_module + elif isinstance( + export_module, + (EdgeProgramManager, exir.program._program.EdgeProgramManager), + ): + for method in export_module.methods: + graph_map[f"{module_name}/{method}"] = export_module.exported_program( + method ) - if contains_reserved_name: - raise RuntimeError( - f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace." - ) - _handle_export_module(etrecord_zip, export_module, module_name) + else: + raise RuntimeError(f"Unsupported graph module type. {type(export_module)}") + +def _process_edge_dialect_program( + edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram] +) -> ExportedProgram: + """Process edge dialect program and return the exported program.""" if isinstance( edge_dialect_program, (EdgeProgramManager, exir.program._program.EdgeProgramManager), ): - _handle_edge_dialect_exported_program( - etrecord_zip, - edge_dialect_program.exported_program(), - ) + return edge_dialect_program.exported_program() elif isinstance(edge_dialect_program, ExirExportedProgram): - _handle_edge_dialect_exported_program( - etrecord_zip, - edge_dialect_program.exported_program, - ) + return edge_dialect_program.exported_program else: raise RuntimeError( f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}." ) - # When a BundledProgram is passed in, extract the reference outputs and save in a file + +def _process_executorch_program( + executorch_program: Union[ + ExecutorchProgram, ExecutorchProgramManager, BundledProgram + ] +) -> tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[List]]: + """Process executorch program and return debug maps and bundled program data.""" if isinstance(executorch_program, BundledProgram): reference_outputs = _get_reference_outputs(executorch_program) - etrecord_zip.writestr( - ETRecordReservedFileNames.REFERENCE_OUTPUTS, - # @lint-ignore PYTHONPICKLEISBAD - pickle.dumps(reference_outputs), - ) - representative_inputs = _get_representative_inputs(executorch_program) - etrecord_zip.writestr( - ETRecordReservedFileNames.REPRESENTATIVE_INPUTS, - # @lint-ignore PYTHONPICKLEISBAD - pickle.dumps(representative_inputs), - ) - executorch_program = executorch_program.executorch_program - - etrecord_zip.writestr( - ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME, - json.dumps(executorch_program.debug_handle_map), - ) - - etrecord_zip.writestr( - ETRecordReservedFileNames.DELEGATE_MAP_NAME, - json.dumps(executorch_program.delegate_map), - ) - - etrecord_zip.writestr( - ETRecordReservedFileNames.EXPORT_GRAPH_ID, - json.dumps(export_graph_id), - ) + # pyre-ignore[16]: Item `None` of `typing.Union[None, exir.program._program.ExecutorchProgram, exir.program._program.ExecutorchProgramManager]` has no attribute `debug_handle_map` + debug_handle_map = executorch_program.executorch_program.debug_handle_map + # pyre-ignore[16]: Item `None` of `typing.Union[None, exir.program._program.ExecutorchProgram, exir.program._program.ExecutorchProgramManager]` has no attribute `debug_handle_map` + delegate_map = executorch_program.executorch_program.delegate_map + return debug_handle_map, delegate_map, reference_outputs, representative_inputs + else: + debug_handle_map = executorch_program.debug_handle_map + delegate_map = executorch_program.delegate_map + return debug_handle_map, delegate_map, None, None def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 432397347a5..9b9f3290162 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -20,6 +20,7 @@ from executorch.devtools.etrecord._etrecord import ( _get_reference_outputs, _get_representative_inputs, + ETRecord, ETRecordReservedFileNames, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge @@ -251,6 +252,122 @@ def test_etrecord_generation_with_exported_program(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + def test_etrecord_class_constructor_and_save(self): + """Test that ETRecord class constructor and save method work correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + original_exported_program = captured_output.exported_program + expected_graph_id = id(original_exported_program.graph) + + # Create ETRecord instance directly using constructor + etrecord = ETRecord( + exported_program=original_exported_program, + export_graph_id=expected_graph_id, + edge_dialect_program=edge_output.exported_program, + graph_map={"test_module/forward": original_exported_program}, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_direct.bin" + + # Use the save method + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + original_exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate graph map + self.assertIsNotNone(parsed_etrecord.graph_map) + self.assertIn("test_module/forward", parsed_etrecord.graph_map) + self.check_graph_closeness( + parsed_etrecord.graph_map["test_module/forward"], + original_exported_program.graph_module, + ) + + # Validate debug and delegate maps + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + + # Validate export graph id + self.assertEqual(parsed_etrecord.export_graph_id, expected_graph_id) + + def test_etrecord_class_with_bundled_program_data(self): + """Test ETRecord class with bundled program data.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Extract bundled program data + reference_outputs = _get_reference_outputs(bundled_program) + representative_inputs = _get_representative_inputs(bundled_program) + + # Create ETRecord instance with bundled program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=bundled_program.executorch_program.debug_handle_map, + _delegate_map=bundled_program.executorch_program.delegate_map, + _reference_outputs=reference_outputs, + _representative_inputs=representative_inputs, + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_bundled.bin" + + # Save using the save method + etrecord.save(etrecord_path) + + # Parse and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate bundled program specific data + self.assertIsNotNone(parsed_etrecord._reference_outputs) + self.assertIsNotNone(parsed_etrecord._representative_inputs) + + # Compare reference outputs + expected_outputs = parsed_etrecord._reference_outputs + self.assertTrue( + torch.equal( + expected_outputs["forward"][0][0], + reference_outputs["forward"][0][0], + ) + ) + self.assertTrue( + torch.equal( + expected_outputs["forward"][1][0], + reference_outputs["forward"][1][0], + ) + ) + + # Compare representative inputs + expected_inputs = parsed_etrecord._representative_inputs + for expected, actual in zip(expected_inputs, representative_inputs): + self.assertTrue(torch.equal(expected[0], actual[0])) + self.assertTrue(torch.equal(expected[1], actual[1])) + def test_etrecord_generation_with_exported_program_dict(self): """Test that exported program dictionary can be recorded and parsed back correctly.""" captured_output, edge_output, et_output = self.get_test_model()