From 499d7e9858b941a43cd3d3ec24b48dc6809276eb Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 8 Jul 2025 15:07:08 -0700 Subject: [PATCH] make etrecord support export program Differential Revision: [D77965102](https://our.internmc.facebook.com/intern/diff/D77965102/) [ghstack-poisoned] --- devtools/etrecord/_etrecord.py | 52 ++++++++++--- devtools/etrecord/tests/etrecord_test.py | 78 ++++++++++++++++++- devtools/inspector/tests/inspector_test.py | 2 +- .../inspector/tests/inspector_utils_test.py | 2 +- .../devtools/scripts/gen_sample_etrecord.py | 2 +- 5 files changed, 120 insertions(+), 16 deletions(-) diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index ffb81a8e41a..3a774db5854 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -45,6 +45,7 @@ class StrEnum(str, Enum): class ETRecordReservedFileNames(StrEnum): ETRECORD_IDENTIFIER = "ETRECORD_V0" + EXPORTED_PROGRAM = "exported_program" EDGE_DIALECT_EXPORTED_PROGRAM = "edge_dialect_exported_program" ET_DIALECT_GRAPH_MODULE = "et_dialect_graph_module" DEBUG_HANDLE_MAP_NAME = "debug_handle_map" @@ -55,6 +56,7 @@ class ETRecordReservedFileNames(StrEnum): @dataclass class ETRecord: + exported_program: Optional[ExportedProgram] = None edge_dialect_program: Optional[ExportedProgram] = None graph_map: Optional[Dict[str, ExportedProgram]] = None _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None @@ -71,17 +73,20 @@ def _handle_exported_program( assert isinstance(ep, ExportedProgram) serialized_artifact = serialize(ep) assert isinstance(serialized_artifact.exported_program, bytes) + + method_name = f"/{method_name}" if method_name != "" else "" + etrecord_zip.writestr( - f"{module_name}/{method_name}", serialized_artifact.exported_program + f"{module_name}{method_name}", serialized_artifact.exported_program ) etrecord_zip.writestr( - f"{module_name}/{method_name}_state_dict", serialized_artifact.state_dict + f"{module_name}{method_name}_state_dict", serialized_artifact.state_dict ) etrecord_zip.writestr( - f"{module_name}/{method_name}_constants", serialized_artifact.constants + f"{module_name}{method_name}_constants", serialized_artifact.constants ) etrecord_zip.writestr( - f"{module_name}/{method_name}_example_inputs", + f"{module_name}{method_name}_example_inputs", serialized_artifact.example_inputs, ) @@ -188,7 +193,10 @@ def generate_etrecord( ExecutorchProgramManager, BundledProgram, ], - export_modules: Optional[ + exported_program: Optional[ + Union[ExportedProgram, Dict[str, ExportedProgram]] + ] = None, + extra_recorded_export_modules: Optional[ Dict[ str, Union[ @@ -202,7 +210,7 @@ def generate_etrecord( """ Generates an `ETRecord` from the given objects, serializes it and saves it to the given path. The objects that will be serialized to an `ETRecord` are all the graph modules present - in the `export_modules` dict, the graph module present in the edge dialect program object, + in the `extra_recorded_export_modules` dict, the graph module present in the edge dialect program object, and also the graph module present in the ExecuTorch program object, which is the closest graph module representation of what is eventually run on the device. In addition to all the graph modules, we also serialize the program buffer, which the users @@ -213,7 +221,8 @@ def generate_etrecord( et_record: Path to where the `ETRecord` file will be saved to. edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge() executorch_program: The ExecuTorch program for this model returned by the call to `to_executorch()` or the `BundledProgram` of this model - export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the + exported_program: Optional graph module for this model returned by the call to `torch.export` from nn.Module. + extra_recorded_export_modules [Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the value being the corresponding exported module. The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`. @@ -229,15 +238,28 @@ def generate_etrecord( # is an etrecord when it's used later in the Developer Tools. etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "") - if export_modules is not None: - for module_name, export_module in export_modules.items(): + if exported_program is not None: + # If multiple exported programs are provided, only saved forward method + if isinstance(exported_program, dict) and "forward" in exported_program: + exported_program = exported_program["forward"] + + if isinstance(exported_program, ExportedProgram): + _handle_exported_program( + etrecord_zip, + ETRecordReservedFileNames.EXPORTED_PROGRAM, + "", + exported_program, + ) + + if extra_recorded_export_modules is not None: + for module_name, export_module in extra_recorded_export_modules.items(): contains_reserved_name = any( reserved_name in module_name for reserved_name in ETRecordReservedFileNames ) if contains_reserved_name: raise RuntimeError( - f"The name {module_name} provided in the export_modules dict is a reserved name in the ETRecord namespace." + f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace." ) _handle_export_module(etrecord_zip, export_module, module_name) @@ -318,6 +340,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 graph_map: Dict[str, ExportedProgram] = {} debug_handle_map = None delegate_map = None + exported_program = None edge_dialect_program = None reference_outputs = None representative_inputs = None @@ -347,6 +370,14 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 etrecord_zip.read(f"{entry}_example_inputs"), ) edge_dialect_program = deserialize(serialized_artifact) + elif entry == ETRecordReservedFileNames.EXPORTED_PROGRAM: + serialized_artifact = SerializedArtifact( + etrecord_zip.read(ETRecordReservedFileNames.EXPORTED_PROGRAM), + etrecord_zip.read(f"{entry}_state_dict"), + etrecord_zip.read(f"{entry}_constants"), + etrecord_zip.read(f"{entry}_example_inputs"), + ) + exported_program = deserialize(serialized_artifact) elif entry == ETRecordReservedFileNames.REFERENCE_OUTPUTS: # @lint-ignore PYTHONPICKLEISBAD reference_outputs = pickle.loads( @@ -383,6 +414,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 graph_map[serialized_file] = deserialize(serialized_artifact) return ETRecord( + exported_program=exported_program, edge_dialect_program=edge_dialect_program, graph_map=graph_map, _debug_handle_map=debug_handle_map, diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index dd1d40e0292..1d7efdedf2e 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -100,12 +100,13 @@ def test_etrecord_generation(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - { + extra_recorded_export_modules={ "aten_dialect_output": captured_output, }, ) etrecord = parse_etrecord(tmpdirname + "/etrecord.bin") + self.check_graph_closeness( etrecord.graph_map["aten_dialect_output/forward"], captured_output.exported_program.graph_module, @@ -184,7 +185,7 @@ def test_etrecord_invalid_input(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - {"fail_test_case": et_output}, + extra_recorded_export_modules={"fail_test_case": et_output}, ) def test_etrecord_reserved_name(self): @@ -196,5 +197,76 @@ def test_etrecord_reserved_name(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - {reserved_name: captured_output.exported_program.graph_module}, + extra_recorded_export_modules={ + reserved_name: captured_output.exported_program.graph_module + }, ) + + def test_etrecord_generation_with_exported_program(self): + """Test that exported program can be recorded and parsed back correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + original_exported_program = captured_output.exported_program + + with tempfile.TemporaryDirectory() as tmpdirname: + # Generate ETRecord with exported program + generate_etrecord( + tmpdirname + "/etrecord.bin", + edge_output, + et_output, + exported_program=original_exported_program, + ) + + # Parse ETRecord back + etrecord = parse_etrecord(tmpdirname + "/etrecord.bin") + + # Validate that the parsed exported program matches the original + self.assertIsNotNone(etrecord.exported_program) + self.check_graph_closeness( + etrecord.exported_program, + original_exported_program.graph_module, + ) + + # Validate other components are still present + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + + def test_etrecord_generation_with_exported_program_dict(self): + """Test that exported program dictionary can be recorded and parsed back correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + original_exported_program = captured_output.exported_program + exported_program_dict = {"forward": original_exported_program} + + with tempfile.TemporaryDirectory() as tmpdirname: + # Generate ETRecord with exported program dictionary + generate_etrecord( + tmpdirname + "/etrecord.bin", + edge_output, + et_output, + exported_program=exported_program_dict, + ) + + # Parse ETRecord back + etrecord = parse_etrecord(tmpdirname + "/etrecord.bin") + + # Validate that the parsed exported program matches the original + self.assertIsNotNone(etrecord.exported_program) + self.check_graph_closeness( + etrecord.exported_program, + original_exported_program.graph_module, + ) + + # Validate other components are still present + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 7c294d81571..7e99bface2c 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -327,7 +327,7 @@ def test_inspector_get_exported_program(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - { + extra_recorded_export_modules={ "aten_dialect_output": captured_output, }, ) diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index d7707ffa199..6a399dc41c4 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -52,7 +52,7 @@ def test_gen_graphs_from_etrecord(self): tmpdirname + "/etrecord.bin", edge_output, et_output, - { + extra_recorded_export_modules={ "aten_dialect_output": captured_output, }, ) diff --git a/examples/devtools/scripts/gen_sample_etrecord.py b/examples/devtools/scripts/gen_sample_etrecord.py index a6b3d487251..e5b46cdede5 100644 --- a/examples/devtools/scripts/gen_sample_etrecord.py +++ b/examples/devtools/scripts/gen_sample_etrecord.py @@ -41,7 +41,7 @@ def gen_etrecord(model: torch.nn.Module, inputs: Any, output_path=None): (DEFAULT_OUTPUT_PATH if not output_path else output_path), edge_dialect_program=edge_program, executorch_program=et_program, - export_modules={ + extra_recorded_export_modules={ "aten_dialect_output": aten_dialect, }, )