Skip to content

Commit 2202e7f

Browse files
committed
add more export modules after ertrecod created
Pull Request resolved: #13010 we need to support etrecord recording custom export modules for further usage. This diff makes that happen by creating new function inside ETRecord ghstack-source-id: 300161197 Differential Revision: [D79279401](https://our.internmc.facebook.com/intern/diff/D79279401/)
1 parent 8651d31 commit 2202e7f

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,38 @@ def _save_edge_dialect_program(
200200
f"{base_name}_example_inputs", serialized_artifact.example_inputs
201201
)
202202

203+
def add_extra_export_modules(
204+
self,
205+
extra_recorded_export_modules: Dict[
206+
str,
207+
Union[
208+
ExportedProgram,
209+
ExirExportedProgram,
210+
EdgeProgramManager,
211+
],
212+
],
213+
) -> None:
214+
"""
215+
Add extra export modules to the ETRecord after it has been created.
216+
217+
This method allows users to add more export modules they want to record
218+
to an existing ETRecord instance. The modules will be added to the graph_map
219+
and will be included when the ETRecord is saved.
220+
221+
Args:
222+
extra_recorded_export_modules: A dictionary of graph modules with the key being
223+
the user provided name and the value being the corresponding exported module.
224+
The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`.
225+
"""
226+
if self.graph_map is None:
227+
self.graph_map = {}
228+
229+
# Now self.graph_map is guaranteed to be non-None
230+
graph_map = self.graph_map
231+
for module_name, export_module in extra_recorded_export_modules.items():
232+
_validate_module_name(module_name)
233+
_add_module_to_graph_map(graph_map, module_name, export_module)
234+
203235

204236
def _get_reference_outputs(
205237
bundled_program: BundledProgram,

devtools/etrecord/tests/etrecord_test.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,73 @@ def test_etrecord_generation_with_exported_program(self):
252252
# Validate that export_graph_id matches the expected value
253253
self.assertEqual(etrecord.export_graph_id, expected_graph_id)
254254

255+
def test_add_extra_export_modules(self):
256+
"""Test add_extra_export_modules when ETRecord already has a graph_map."""
257+
captured_output, edge_output, et_output = self.get_test_model()
258+
259+
# Create an ETRecord instance with existing graph_map
260+
initial_graph_map = {
261+
"existing_module/forward": captured_output.exported_program
262+
}
263+
etrecord = ETRecord(
264+
exported_program=captured_output.exported_program,
265+
export_graph_id=id(captured_output.exported_program.graph),
266+
edge_dialect_program=edge_output.exported_program,
267+
graph_map=initial_graph_map,
268+
_debug_handle_map=et_output.debug_handle_map,
269+
_delegate_map=et_output.delegate_map,
270+
)
271+
272+
# Verify initial state
273+
self.assertIsNotNone(etrecord.graph_map)
274+
self.assertIn("existing_module/forward", etrecord.graph_map)
275+
276+
# Create additional module to add
277+
f2 = models.BasicSinMax()
278+
captured_output2 = exir.capture(
279+
f2, f2.get_random_inputs(), exir.CaptureConfig()
280+
)
281+
282+
extra_modules = {
283+
"new_module": captured_output2.exported_program,
284+
}
285+
286+
# Add extra export modules
287+
etrecord.add_extra_export_modules(extra_modules)
288+
289+
# Verify both existing and new modules are present
290+
self.assertIn("existing_module/forward", etrecord.graph_map)
291+
self.assertIn("new_module/forward", etrecord.graph_map)
292+
293+
# Verify the modules are correctly stored
294+
self.check_graph_closeness(
295+
etrecord.graph_map["existing_module/forward"],
296+
captured_output.exported_program.graph_module,
297+
)
298+
self.check_graph_closeness(
299+
etrecord.graph_map["new_module/forward"],
300+
captured_output2.exported_program.graph_module,
301+
)
302+
303+
def test_add_extra_export_modules_reserved_name_validation(self):
304+
"""Test that add_extra_export_modules validates reserved names."""
305+
captured_output, edge_output, et_output = self.get_test_model()
306+
307+
etrecord = ETRecord(
308+
exported_program=captured_output.exported_program,
309+
export_graph_id=id(captured_output.exported_program.graph),
310+
edge_dialect_program=edge_output.exported_program,
311+
_debug_handle_map=et_output.debug_handle_map,
312+
_delegate_map=et_output.delegate_map,
313+
)
314+
315+
# Test that reserved names are rejected
316+
for reserved_name in ETRecordReservedFileNames:
317+
with self.assertRaises(RuntimeError):
318+
etrecord.add_extra_export_modules(
319+
{reserved_name: captured_output.exported_program}
320+
)
321+
255322
def test_etrecord_class_constructor_and_save(self):
256323
"""Test that ETRecord class constructor and save method work correctly."""
257324
captured_output, edge_output, et_output = self.get_test_model()

0 commit comments

Comments
 (0)