From 2202e7fc39302b58d9d337fc351baa021d8c9092 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Fri, 1 Aug 2025 11:22:32 -0700 Subject: [PATCH] add more export modules after ertrecod created Pull Request resolved: https://github.com/pytorch/executorch/pull/13010 we need to support etrecord recording custom export modules for further usage. This diff makes that happen by creating new function inside ETRecord ghstack-source-id: 300161197 Differential Revision: [D79279401](https://our.internmc.facebook.com/intern/diff/D79279401/) --- devtools/etrecord/_etrecord.py | 32 +++++++++++ devtools/etrecord/tests/etrecord_test.py | 67 ++++++++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index e149aeab650..a727d3c172d 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -200,6 +200,38 @@ def _save_edge_dialect_program( f"{base_name}_example_inputs", serialized_artifact.example_inputs ) + def add_extra_export_modules( + self, + extra_recorded_export_modules: Dict[ + str, + Union[ + ExportedProgram, + ExirExportedProgram, + EdgeProgramManager, + ], + ], + ) -> None: + """ + Add extra export modules to the ETRecord after it has been created. + + This method allows users to add more export modules they want to record + to an existing ETRecord instance. The modules will be added to the graph_map + and will be included when the ETRecord is saved. + + Args: + extra_recorded_export_modules: A dictionary of graph modules with the key being + the user provided name and the value being the corresponding exported module. + The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`. + """ + if self.graph_map is None: + self.graph_map = {} + + # Now self.graph_map is guaranteed to be non-None + graph_map = self.graph_map + for module_name, export_module in extra_recorded_export_modules.items(): + _validate_module_name(module_name) + _add_module_to_graph_map(graph_map, module_name, export_module) + def _get_reference_outputs( bundled_program: BundledProgram, diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 9b9f3290162..28c3f87a243 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -252,6 +252,73 @@ def test_etrecord_generation_with_exported_program(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + def test_add_extra_export_modules(self): + """Test add_extra_export_modules when ETRecord already has a graph_map.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing graph_map + initial_graph_map = { + "existing_module/forward": captured_output.exported_program + } + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + graph_map=initial_graph_map, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state + self.assertIsNotNone(etrecord.graph_map) + self.assertIn("existing_module/forward", etrecord.graph_map) + + # Create additional module to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + + extra_modules = { + "new_module": captured_output2.exported_program, + } + + # Add extra export modules + etrecord.add_extra_export_modules(extra_modules) + + # Verify both existing and new modules are present + self.assertIn("existing_module/forward", etrecord.graph_map) + self.assertIn("new_module/forward", etrecord.graph_map) + + # Verify the modules are correctly stored + self.check_graph_closeness( + etrecord.graph_map["existing_module/forward"], + captured_output.exported_program.graph_module, + ) + self.check_graph_closeness( + etrecord.graph_map["new_module/forward"], + captured_output2.exported_program.graph_module, + ) + + def test_add_extra_export_modules_reserved_name_validation(self): + """Test that add_extra_export_modules validates reserved names.""" + captured_output, edge_output, et_output = self.get_test_model() + + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Test that reserved names are rejected + for reserved_name in ETRecordReservedFileNames: + with self.assertRaises(RuntimeError): + etrecord.add_extra_export_modules( + {reserved_name: captured_output.exported_program} + ) + def test_etrecord_class_constructor_and_save(self): """Test that ETRecord class constructor and save method work correctly.""" captured_output, edge_output, et_output = self.get_test_model()