From a8114d3599c2cca88dbad8c6c7795719dd24d038 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 30 Jul 2025 12:14:03 -0700 Subject: [PATCH] add more export modules after ertrecod created we need to support etrecord recording custom export modules for further usage. This diff makes that happen by creating new function inside ETRecord Differential Revision: [D79279401](https://our.internmc.facebook.com/intern/diff/D79279401/) [ghstack-poisoned] --- devtools/etrecord/_etrecord.py | 32 +++++++++++ devtools/etrecord/tests/etrecord_test.py | 67 ++++++++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index e149aeab650..a727d3c172d 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -200,6 +200,38 @@ def _save_edge_dialect_program( f"{base_name}_example_inputs", serialized_artifact.example_inputs ) + def add_extra_export_modules( + self, + extra_recorded_export_modules: Dict[ + str, + Union[ + ExportedProgram, + ExirExportedProgram, + EdgeProgramManager, + ], + ], + ) -> None: + """ + Add extra export modules to the ETRecord after it has been created. + + This method allows users to add more export modules they want to record + to an existing ETRecord instance. The modules will be added to the graph_map + and will be included when the ETRecord is saved. + + Args: + extra_recorded_export_modules: A dictionary of graph modules with the key being + the user provided name and the value being the corresponding exported module. + The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`. + """ + if self.graph_map is None: + self.graph_map = {} + + # Now self.graph_map is guaranteed to be non-None + graph_map = self.graph_map + for module_name, export_module in extra_recorded_export_modules.items(): + _validate_module_name(module_name) + _add_module_to_graph_map(graph_map, module_name, export_module) + def _get_reference_outputs( bundled_program: BundledProgram, diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 9b9f3290162..28c3f87a243 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -252,6 +252,73 @@ def test_etrecord_generation_with_exported_program(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + def test_add_extra_export_modules(self): + """Test add_extra_export_modules when ETRecord already has a graph_map.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing graph_map + initial_graph_map = { + "existing_module/forward": captured_output.exported_program + } + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + graph_map=initial_graph_map, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state + self.assertIsNotNone(etrecord.graph_map) + self.assertIn("existing_module/forward", etrecord.graph_map) + + # Create additional module to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + + extra_modules = { + "new_module": captured_output2.exported_program, + } + + # Add extra export modules + etrecord.add_extra_export_modules(extra_modules) + + # Verify both existing and new modules are present + self.assertIn("existing_module/forward", etrecord.graph_map) + self.assertIn("new_module/forward", etrecord.graph_map) + + # Verify the modules are correctly stored + self.check_graph_closeness( + etrecord.graph_map["existing_module/forward"], + captured_output.exported_program.graph_module, + ) + self.check_graph_closeness( + etrecord.graph_map["new_module/forward"], + captured_output2.exported_program.graph_module, + ) + + def test_add_extra_export_modules_reserved_name_validation(self): + """Test that add_extra_export_modules validates reserved names.""" + captured_output, edge_output, et_output = self.get_test_model() + + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Test that reserved names are rejected + for reserved_name in ETRecordReservedFileNames: + with self.assertRaises(RuntimeError): + etrecord.add_extra_export_modules( + {reserved_name: captured_output.exported_program} + ) + def test_etrecord_class_constructor_and_save(self): """Test that ETRecord class constructor and save method work correctly.""" captured_output, edge_output, et_output = self.get_test_model()