From eed14129353fc3d41d59daedb64c4e2a98bb4cad Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 6 Aug 2025 00:34:29 -0700 Subject: [PATCH] make to_edge support etrecord generation This support to_edge export flow etrecord generation supportive. Details can be found in https://github.com/pytorch/executorch/discussions/12925 Differential Revision: [D79707919](https://our.internmc.facebook.com/intern/diff/D79707919/) [ghstack-poisoned] --- devtools/etrecord/tests/etrecord_test.py | 86 ++++++++++++++++++++++-- exir/program/_program.py | 12 +++- 2 files changed, 93 insertions(+), 5 deletions(-) diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index b3b47c679ee..5658026a8e8 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -24,8 +24,8 @@ ETRecord, ETRecordReservedFileNames, ) -from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge -from executorch.exir.program._program import to_edge_transform_and_lower +from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir.program._program import to_edge, to_edge_transform_and_lower from torch.export import export @@ -105,11 +105,13 @@ def assert_etrecord_saveable(self, etrecord: ETRecord) -> None: self.assertIsNotNone(etrecord._debug_handle_map) self.assertIsNotNone(etrecord._delegate_map) - def get_test_model(self): + def get_test_model(self, generate_etrecord=False): f = models.BasicSinMax() aten_dialect = export(f, f.get_random_inputs(), strict=True) edge_program: EdgeProgramManager = to_edge( - aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False) + aten_dialect, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + generate_etrecord=generate_etrecord, ) edge_program_copy = copy.deepcopy(edge_program) return (aten_dialect, edge_program_copy, edge_program.to_executorch()) @@ -392,6 +394,82 @@ def test_get_etrecord_from_executorch_program_manager_without_generation(self): self.assertIn("ETRecord was not generated", str(context.exception)) + def test_to_edge_with_etrecord_generation(self): + """Test that to_edge generates ETRecord correctly.""" + aten_program, edge_manager, _ = self.get_test_model(generate_etrecord=True) + + # Verify that ETRecord was generated and attached + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + self.assert_legal_etrecord_in_edge_program(etrecord) + + # Verify the exported program matches the input + self.check_graph_closeness( + etrecord.exported_program, + aten_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(aten_program.graph), + ) + + # Verify the edge dialect program matches the edge manager + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_manager.exported_program().graph_module, + ) + + def test_to_edge_without_etrecord_generation(self): + """Test that to_edge works correctly without ETRecord generation.""" + # Test with generate_etrecord=False (default) + _, edge_manager, et_manager = self.get_test_model() + + # Verify that no ETRecord was generated + self.assertIsNone(edge_manager._etrecord) + + # Test get_etrecord method should raise RuntimeError + with self.assertRaises(RuntimeError): + et_manager.get_etrecord() + + def test_to_edge_etrecord_save_and_parse(self): + """Test that ETRecord generated by to_edge can be saved and parsed.""" + aten_program, _, et_manager = self.get_test_model(generate_etrecord=True) + + etrecord = et_manager.get_etrecord() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_to_edge.bin" + + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + # Note: Skip graph structure comparison due to transformation differences + self.check_graph_closeness( + etrecord.exported_program, parsed_etrecord.exported_program + ) + self.check_graph_closeness( + etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program + ) + + # Validate executorch program data + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_manager.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_manager.delegate_map)), + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(aten_program.graph), + ) + def test_to_edge_transform_and_lower_etrecord_save_and_parse(self): """Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed.""" f = models.BasicSinMax() diff --git a/exir/program/_program.py b/exir/program/_program.py index 63b49d9860d..1ec63b3d204 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1351,6 +1351,7 @@ def to_edge( programs: Union[ExportedProgram, Dict[str, ExportedProgram]], constant_methods: Optional[Dict[str, Any]] = None, compile_config: Optional[EdgeCompileConfig] = None, + generate_etrecord: Optional[bool] = False, ) -> "EdgeProgramManager": """ :func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in @@ -1363,6 +1364,8 @@ def to_edge( compile_config: An optional argument used to provide greater control over the transformation to edge dialect process. + generate_etrecord: An optional argument used to generate an etrecord for debugging purposes. + Returns: EdgeProgramManager """ @@ -1416,7 +1419,14 @@ def to_edge( logging.info(f"Input program {name} is not in Edge dialect.") raise e - return EdgeProgramManager(edge_programs, constant_methods, config) + epm = EdgeProgramManager(edge_programs, constant_methods, config) + if generate_etrecord: + etrecord = _create_empty_etrecord() + etrecord.add_exported_program(aten_programs) + etrecord.add_edge_dialect_program(copy.deepcopy(epm)) + epm._etrecord = etrecord + + return epm class EdgeProgramManager: