From b4fc97c2c59d9321ab0261eaa41dbb3e84f6ceb4 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 30 Jul 2025 16:04:30 -0700 Subject: [PATCH] add executorch program equipment support in etrecord class Previously we have to provide all essentail infos at the same time to generate etrecord; however if we want to generate it through export flow we can not find a stage that having all essential infos so that we need to have a new way to contruct it on-the-fly. This diff makes the target happen by adding three functions: `add_exported_program`, `add_edge_dialect_program` and `add_executorch_program` so that whenever we have the required info we can equip it into etrecord. Also update test case for test coverage. Differential Revision: [D79294945](https://our.internmc.facebook.com/intern/diff/D79294945/) [ghstack-poisoned] --- devtools/etrecord/_etrecord.py | 171 ++++-- devtools/etrecord/tests/etrecord_test.py | 636 +++++++++++++++++++++++ 2 files changed, 761 insertions(+), 46 deletions(-) diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index a727d3c172d..3b8a71279fd 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -229,9 +229,122 @@ def add_extra_export_modules( # Now self.graph_map is guaranteed to be non-None graph_map = self.graph_map for module_name, export_module in extra_recorded_export_modules.items(): - _validate_module_name(module_name) _add_module_to_graph_map(graph_map, module_name, export_module) + def add_executorch_program( + self, + executorch_program: Union[ + ExecutorchProgram, + ExecutorchProgramManager, + BundledProgram, + ], + ) -> None: + """ + Add executorch program data to the ETRecord after it has been created. + + This method allows users to add executorch program data they want to record + to an existing ETRecord instance. The executorch program data includes debug handle map, + delegate map, reference outputs, and representative inputs that will be included + when the ETRecord is saved. + + Args: + executorch_program: The ExecuTorch program for this model returned by the call to + `to_executorch()` or the `BundledProgram` of this model. + + Raises: + RuntimeError: If executorch program data already exists in the ETRecord. + """ + # Check if executorch program data already exists + if ( + self._debug_handle_map is not None + or self._delegate_map is not None + or self._reference_outputs is not None + or self._representative_inputs is not None + ): + raise RuntimeError( + "Executorch program data already exists in the ETRecord. " + "Cannot add executorch program data when it already exists." + ) + + # Process executorch program and extract data + debug_handle_map, delegate_map, reference_outputs, representative_inputs = ( + _process_executorch_program(executorch_program) + ) + + # Set the extracted data + self._debug_handle_map = debug_handle_map + self._delegate_map = delegate_map + self._reference_outputs = reference_outputs + self._representative_inputs = representative_inputs + + def add_exported_program( + self, + exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]], + ) -> None: + """ + Add exported program to the ETRecord after it has been created. + + This method allows users to add an exported program they want to record + to an existing ETRecord instance. The exported program will be included + when the ETRecord is saved. + + Args: + exported_program: The exported program for this model returned by the call to + `torch.export()` or a dictionary with method names as keys and exported programs as values. + Can be None, in which case no exported program data will be added. + + Raises: + RuntimeError: If exported program already exists in the ETRecord. + """ + # Check if exported program already exists + if self.exported_program is not None or self.export_graph_id is not None: + raise RuntimeError( + "Exported program already exists in the ETRecord. " + "Cannot add exported program when it already exists." + ) + + # Process exported program and extract data + processed_exported_program, export_graph_id = _process_exported_program( + exported_program + ) + + # Set the extracted data + self.exported_program = processed_exported_program + self.export_graph_id = export_graph_id + + def add_edge_dialect_program( + self, + edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram], + ) -> None: + """ + Add edge dialect program to the ETRecord after it has been created. + + This method allows users to add an edge dialect program they want to record + to an existing ETRecord instance. The edge dialect program will be included + when the ETRecord is saved. + + Args: + edge_dialect_program: The edge dialect program for this model returned by the call to + `to_edge()` or `EdgeProgramManager` for this model. + + Raises: + RuntimeError: If edge dialect program already exists in the ETRecord. + """ + # Check if edge dialect program already exists + if self.edge_dialect_program is not None: + raise RuntimeError( + "Edge dialect program already exists in the ETRecord. " + "Cannot add edge dialect program when it already exists." + ) + + # Process edge dialect program and extract data + processed_edge_dialect_program = _process_edge_dialect_program( + edge_dialect_program + ) + + # Set the extracted data + self.edge_dialect_program = processed_edge_dialect_program + def _get_reference_outputs( bundled_program: BundledProgram, @@ -317,37 +430,24 @@ def generate_etrecord( Returns: None """ - # Process all inputs and prepare data for ETRecord construction - processed_exported_program, export_graph_id = _process_exported_program( - exported_program - ) - graph_map = _process_extra_recorded_modules(extra_recorded_export_modules) - processed_edge_dialect_program = _process_edge_dialect_program(edge_dialect_program) - debug_handle_map, delegate_map, reference_outputs, representative_inputs = ( - _process_executorch_program(executorch_program) - ) + etrecord = ETRecord() + etrecord.add_exported_program(exported_program) + etrecord.add_edge_dialect_program(edge_dialect_program) + etrecord.add_executorch_program(executorch_program) - # Create ETRecord instance and save - etrecord = ETRecord( - exported_program=processed_exported_program, - export_graph_id=export_graph_id, - edge_dialect_program=processed_edge_dialect_program, - graph_map=graph_map if graph_map else None, - _debug_handle_map=debug_handle_map, - _delegate_map=delegate_map, - _reference_outputs=reference_outputs, - _representative_inputs=representative_inputs, - ) + # Add extra export modules if user provided + if extra_recorded_export_modules is not None: + etrecord.add_extra_export_modules(extra_recorded_export_modules) etrecord.save(et_record) def _process_exported_program( exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]] -) -> tuple[Optional[ExportedProgram], int]: +) -> tuple[Optional[ExportedProgram], Optional[int]]: """Process exported program and return the processed program and export graph id.""" processed_exported_program = None - export_graph_id = 0 + export_graph_id = None if exported_program is not None: if isinstance(exported_program, dict) and "forward" in exported_program: @@ -361,29 +461,6 @@ def _process_exported_program( return processed_exported_program, export_graph_id -def _process_extra_recorded_modules( - extra_recorded_export_modules: Optional[ - Dict[ - str, - Union[ - ExportedProgram, - ExirExportedProgram, - EdgeProgramManager, - ], - ] - ] -) -> Dict[str, ExportedProgram]: - """Process extra recorded export modules and return graph map.""" - graph_map = {} - - if extra_recorded_export_modules is not None: - for module_name, export_module in extra_recorded_export_modules.items(): - _validate_module_name(module_name) - _add_module_to_graph_map(graph_map, module_name, export_module) - - return graph_map - - def _validate_module_name(module_name: str) -> None: """Validate that module name is not a reserved name.""" contains_reserved_name = any( @@ -401,6 +478,8 @@ def _add_module_to_graph_map( export_module: Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager], ) -> None: """Add export module to graph map based on its type.""" + _validate_module_name(module_name) + if isinstance(export_module, ExirExportedProgram): graph_map[f"{module_name}/forward"] = export_module.exported_program elif isinstance(export_module, ExportedProgram): diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 28c3f87a243..0720caeac54 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -29,6 +29,75 @@ # TODO : T154728484 Add test cases to cover multiple entry points class TestETRecord(unittest.TestCase): + def assert_etrecord_has_no_exported_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no exported program data.""" + self.assertIsNone(etrecord.exported_program) + self.assertIsNone(etrecord.export_graph_id) + + def assert_etrecord_has_no_edge_dialect_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no edge dialect program data.""" + self.assertIsNone(etrecord.edge_dialect_program) + + def assert_etrecord_has_no_executorch_program(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no executorch program data.""" + self.assertIsNone(etrecord._debug_handle_map) + self.assertIsNone(etrecord._delegate_map) + self.assertIsNone(etrecord._reference_outputs) + self.assertIsNone(etrecord._representative_inputs) + + def assert_etrecord_is_empty(self, etrecord: ETRecord) -> None: + """Assert that ETRecord has no data at all.""" + self.assert_etrecord_has_no_exported_program(etrecord) + self.assert_etrecord_has_no_edge_dialect_program(etrecord) + self.assert_etrecord_has_no_executorch_program(etrecord) + self.assertIsNone(etrecord.graph_map) + + def assert_etrecord_has_exported_program( + self, etrecord: ETRecord, expected_exported_program + ) -> None: + """Assert that ETRecord has exported program data matching expected.""" + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.check_graph_closeness( + etrecord.exported_program, + expected_exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(expected_exported_program.graph), + ) + + def assert_etrecord_has_edge_dialect_program( + self, etrecord: ETRecord, expected_edge_program + ) -> None: + """Assert that ETRecord has edge dialect program data matching expected.""" + self.assertIsNotNone(etrecord.edge_dialect_program) + if hasattr(expected_edge_program, "exported_program"): + # EdgeProgramManager case + expected_graph_module = expected_edge_program.exported_program.graph_module + else: + # ExirExportedProgram case + expected_graph_module = expected_edge_program.exported_program.graph_module + self.check_graph_closeness( + etrecord.edge_dialect_program, + expected_graph_module, + ) + + def assert_etrecord_has_executorch_program( + self, etrecord: ETRecord, expected_et_output + ) -> None: + """Assert that ETRecord has executorch program data matching expected.""" + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(expected_et_output.debug_handle_map)), + ) + self.assertEqual( + etrecord._delegate_map, + json.loads(json.dumps(expected_et_output.delegate_map)), + ) + def get_test_model(self): f = models.BasicSinMax() captured_output = exir.capture(f, f.get_random_inputs(), exir.CaptureConfig()) @@ -473,3 +542,570 @@ def test_etrecord_generation_with_exported_program_dict(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + + def test_add_executorch_program(self): + """Test add_executorch_program when ETRecord has no existing executorch program data.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + ) + + # Verify initial state - no executorch program data + self.assert_etrecord_has_no_executorch_program(etrecord) + + # Add executorch program + etrecord.add_executorch_program(et_output) + + # Verify executorch program data is now present + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + # For regular ExecutorchProgram, reference_outputs and representative_inputs should be None + self.assertIsNone(etrecord._reference_outputs) + self.assertIsNone(etrecord._representative_inputs) + + def test_add_executorch_program_with_bundled_program(self): + """Test add_executorch_program with BundledProgram.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Create an ETRecord instance without executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + ) + + # Verify initial state - no executorch program data + self.assertIsNone(etrecord._debug_handle_map) + self.assertIsNone(etrecord._delegate_map) + self.assertIsNone(etrecord._reference_outputs) + self.assertIsNone(etrecord._representative_inputs) + + # Add bundled program + etrecord.add_executorch_program(bundled_program) + + # Verify executorch program data is now present + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + self.assertIsNotNone(etrecord._reference_outputs) + self.assertIsNotNone(etrecord._representative_inputs) + + # Verify the data matches expected values + expected_reference_outputs = _get_reference_outputs(bundled_program) + expected_representative_inputs = _get_representative_inputs(bundled_program) + + # Compare reference outputs + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][0][0], + expected_reference_outputs["forward"][0][0], + ) + ) + self.assertTrue( + torch.equal( + etrecord._reference_outputs["forward"][1][0], + expected_reference_outputs["forward"][1][0], + ) + ) + + # Compare representative inputs + for expected, actual in zip( + etrecord._representative_inputs, expected_representative_inputs + ): + self.assertTrue(torch.equal(expected[0], actual[0])) + self.assertTrue(torch.equal(expected[1], actual[1])) + + def test_add_executorch_program_already_exists_exception(self): + """Test that add_executorch_program raises exception when executorch program data already exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify that adding executorch program raises RuntimeError + with self.assertRaises(RuntimeError) as context: + etrecord.add_executorch_program(et_output) + + self.assertIn( + "Executorch program data already exists in the ETRecord", + str(context.exception), + ) + + def test_add_executorch_program_partial_data_exists_exception(self): + """Test that add_executorch_program raises exception when partial executorch program data exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with only debug_handle_map (partial data) + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + ) + + # Verify that adding executorch program raises RuntimeError even with partial data + with self.assertRaises(RuntimeError) as context: + etrecord.add_executorch_program(et_output) + + self.assertIn( + "Executorch program data already exists in the ETRecord", + str(context.exception), + ) + + def test_add_executorch_program_and_save(self): + """Test that ETRecord with added executorch program can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without executorch program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + ) + + # Add executorch program + etrecord.add_executorch_program(et_output) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_added_program.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate executorch program data + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_exported_program(self): + """Test add_exported_program when ETRecord has no existing exported program.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no exported program + self.assert_etrecord_has_no_exported_program(etrecord) + + # Add exported program + etrecord.add_exported_program(captured_output.exported_program) + + # Verify exported program is now present + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.check_graph_closeness( + etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_exported_program_with_dict(self): + """Test add_exported_program with dictionary input.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no exported program + self.assertIsNone(etrecord.exported_program) + self.assertIsNone(etrecord.export_graph_id) + + # Add exported program as dictionary + exported_program_dict = {"forward": captured_output.exported_program} + etrecord.add_exported_program(exported_program_dict) + + # Verify exported program is now present + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.check_graph_closeness( + etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_exported_program_already_exists_exception(self): + """Test that add_exported_program raises exception when exported program already exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing exported program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Create another exported program to try to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + + # Verify that adding exported program raises RuntimeError + with self.assertRaises(RuntimeError) as context: + etrecord.add_exported_program(captured_output2.exported_program) + + self.assertIn( + "Exported program already exists in the ETRecord", + str(context.exception), + ) + + def test_add_exported_program_partial_data_exists_exception(self): + """Test that add_exported_program raises exception when partial exported program data exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with only export_graph_id (partial data) + etrecord = ETRecord( + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify that adding exported program raises RuntimeError even with partial data + with self.assertRaises(RuntimeError) as context: + etrecord.add_exported_program(captured_output.exported_program) + + self.assertIn( + "Exported program already exists in the ETRecord", + str(context.exception), + ) + + def test_add_exported_program_with_none(self): + """Test add_exported_program with None input.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no exported program + self.assert_etrecord_has_no_exported_program(etrecord) + + # Add None exported program (should not raise error) + etrecord.add_exported_program(None) + + # Verify exported program is still None + self.assert_etrecord_has_no_exported_program(etrecord) + + def test_add_exported_program_and_save(self): + """Test that ETRecord with added exported program can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without exported program + etrecord = ETRecord( + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Add exported program + etrecord.add_exported_program(captured_output.exported_program) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_added_exported_program.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_edge_dialect_program(self): + """Test add_edge_dialect_program when ETRecord has no existing edge dialect program.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no edge dialect program + self.assert_etrecord_has_no_edge_dialect_program(etrecord) + + # Add edge dialect program + etrecord.add_edge_dialect_program(edge_output) + + # Verify edge dialect program is now present + self.assertIsNotNone(etrecord.edge_dialect_program) + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + def test_add_edge_dialect_program_with_exir_exported_program(self): + """Test add_edge_dialect_program with ExirExportedProgram.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Verify initial state - no edge dialect program + self.assertIsNone(etrecord.edge_dialect_program) + + # Create ExirExportedProgram from captured output + exir_exported_program = captured_output.to_edge( + exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + ) + + # Add edge dialect program using ExirExportedProgram + etrecord.add_edge_dialect_program(exir_exported_program) + + # Verify edge dialect program is now present + self.assertIsNotNone(etrecord.edge_dialect_program) + self.check_graph_closeness( + etrecord.edge_dialect_program, + exir_exported_program.exported_program.graph_module, + ) + + def test_add_edge_dialect_program_already_exists_exception(self): + """Test that add_edge_dialect_program raises exception when edge dialect program already exists.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance with existing edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Create another edge program to try to add + f2 = models.BasicSinMax() + captured_output2 = exir.capture( + f2, f2.get_random_inputs(), exir.CaptureConfig() + ) + edge_output2 = captured_output2.to_edge( + exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False) + ) + + # Verify that adding edge dialect program raises RuntimeError + with self.assertRaises(RuntimeError) as context: + etrecord.add_edge_dialect_program(edge_output2) + + self.assertIn( + "Edge dialect program already exists in the ETRecord", + str(context.exception), + ) + + def test_add_edge_dialect_program_and_save(self): + """Test that ETRecord with added edge dialect program can be saved and parsed correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an ETRecord instance without edge dialect program + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + # Add edge dialect program + etrecord.add_edge_dialect_program(edge_output) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_with_added_edge_program.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate export graph id + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + + def test_add_all_programs_sequentially(self): + """Test adding all programs sequentially to an empty ETRecord.""" + captured_output, edge_output, et_output = self.get_test_model() + + # Create an empty ETRecord instance + etrecord = ETRecord() + + # Verify initial state - everything is None + self.assert_etrecord_is_empty(etrecord) + + # Add exported program + etrecord.add_exported_program(captured_output.exported_program) + + # Add edge dialect program + etrecord.add_edge_dialect_program(edge_output) + + # Add executorch program + etrecord.add_executorch_program(et_output) + + # Verify all components are now present + self.assertIsNotNone(etrecord.exported_program) + self.assertIsNotNone(etrecord.export_graph_id) + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assertIsNotNone(etrecord._debug_handle_map) + self.assertIsNotNone(etrecord._delegate_map) + + # Verify the data matches expected values + self.check_graph_closeness( + etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + self.check_graph_closeness( + etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + self.assertEqual( + etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + self.assertEqual( + etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + + # Test that the complete ETRecord can be saved and parsed + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_complete.bin" + + # Save the ETRecord + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + captured_output.exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate all metadata + self.assertEqual( + parsed_etrecord.export_graph_id, + id(captured_output.exported_program.graph), + ) + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + )