@@ -252,6 +252,73 @@ def test_etrecord_generation_with_exported_program(self):
252252 # Validate that export_graph_id matches the expected value
253253 self .assertEqual (etrecord .export_graph_id , expected_graph_id )
254254
255+ def test_add_extra_export_modules (self ):
256+ """Test add_extra_export_modules when ETRecord already has a graph_map."""
257+ captured_output , edge_output , et_output = self .get_test_model ()
258+
259+ # Create an ETRecord instance with existing graph_map
260+ initial_graph_map = {
261+ "existing_module/forward" : captured_output .exported_program
262+ }
263+ etrecord = ETRecord (
264+ exported_program = captured_output .exported_program ,
265+ export_graph_id = id (captured_output .exported_program .graph ),
266+ edge_dialect_program = edge_output .exported_program ,
267+ graph_map = initial_graph_map ,
268+ _debug_handle_map = et_output .debug_handle_map ,
269+ _delegate_map = et_output .delegate_map ,
270+ )
271+
272+ # Verify initial state
273+ self .assertIsNotNone (etrecord .graph_map )
274+ self .assertIn ("existing_module/forward" , etrecord .graph_map )
275+
276+ # Create additional module to add
277+ f2 = models .BasicSinMax ()
278+ captured_output2 = exir .capture (
279+ f2 , f2 .get_random_inputs (), exir .CaptureConfig ()
280+ )
281+
282+ extra_modules = {
283+ "new_module" : captured_output2 .exported_program ,
284+ }
285+
286+ # Add extra export modules
287+ etrecord .add_extra_export_modules (extra_modules )
288+
289+ # Verify both existing and new modules are present
290+ self .assertIn ("existing_module/forward" , etrecord .graph_map )
291+ self .assertIn ("new_module/forward" , etrecord .graph_map )
292+
293+ # Verify the modules are correctly stored
294+ self .check_graph_closeness (
295+ etrecord .graph_map ["existing_module/forward" ],
296+ captured_output .exported_program .graph_module ,
297+ )
298+ self .check_graph_closeness (
299+ etrecord .graph_map ["new_module/forward" ],
300+ captured_output2 .exported_program .graph_module ,
301+ )
302+
303+ def test_add_extra_export_modules_reserved_name_validation (self ):
304+ """Test that add_extra_export_modules validates reserved names."""
305+ captured_output , edge_output , et_output = self .get_test_model ()
306+
307+ etrecord = ETRecord (
308+ exported_program = captured_output .exported_program ,
309+ export_graph_id = id (captured_output .exported_program .graph ),
310+ edge_dialect_program = edge_output .exported_program ,
311+ _debug_handle_map = et_output .debug_handle_map ,
312+ _delegate_map = et_output .delegate_map ,
313+ )
314+
315+ # Test that reserved names are rejected
316+ for reserved_name in ETRecordReservedFileNames :
317+ with self .assertRaises (RuntimeError ):
318+ etrecord .add_extra_export_modules (
319+ {reserved_name : captured_output .exported_program }
320+ )
321+
255322 def test_etrecord_class_constructor_and_save (self ):
256323 """Test that ETRecord class constructor and save method work correctly."""
257324 captured_output , edge_output , et_output = self .get_test_model ()
0 commit comments