@@ -200,6 +200,151 @@ def _save_edge_dialect_program(
200
200
f"{ base_name } _example_inputs" , serialized_artifact .example_inputs
201
201
)
202
202
203
+ def add_extra_export_modules (
204
+ self ,
205
+ extra_recorded_export_modules : Dict [
206
+ str ,
207
+ Union [
208
+ ExportedProgram ,
209
+ ExirExportedProgram ,
210
+ EdgeProgramManager ,
211
+ ],
212
+ ],
213
+ ) -> None :
214
+ """
215
+ Add extra export modules to the ETRecord after it has been created.
216
+
217
+ This method allows users to add more export modules they want to record
218
+ to an existing ETRecord instance. The modules will be added to the graph_map
219
+ and will be included when the ETRecord is saved.
220
+
221
+ Args:
222
+ extra_recorded_export_modules: A dictionary of graph modules with the key being
223
+ the user provided name and the value being the corresponding exported module.
224
+ The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`.
225
+ """
226
+ if self .graph_map is None :
227
+ self .graph_map = {}
228
+
229
+ # Now self.graph_map is guaranteed to be non-None
230
+ graph_map = self .graph_map
231
+ for module_name , export_module in extra_recorded_export_modules .items ():
232
+ _add_module_to_graph_map (graph_map , module_name , export_module )
233
+
234
+ def add_executorch_program (
235
+ self ,
236
+ executorch_program : Union [
237
+ ExecutorchProgram ,
238
+ ExecutorchProgramManager ,
239
+ BundledProgram ,
240
+ ],
241
+ ) -> None :
242
+ """
243
+ Add executorch program data to the ETRecord after it has been created.
244
+
245
+ This method allows users to add executorch program data they want to record
246
+ to an existing ETRecord instance. The executorch program data includes debug handle map,
247
+ delegate map, reference outputs, and representative inputs that will be included
248
+ when the ETRecord is saved.
249
+
250
+ Args:
251
+ executorch_program: The ExecuTorch program for this model returned by the call to
252
+ `to_executorch()` or the `BundledProgram` of this model.
253
+
254
+ Raises:
255
+ RuntimeError: If executorch program data already exists in the ETRecord.
256
+ """
257
+ # Check if executorch program data already exists
258
+ if (
259
+ self ._debug_handle_map is not None
260
+ or self ._delegate_map is not None
261
+ or self ._reference_outputs is not None
262
+ or self ._representative_inputs is not None
263
+ ):
264
+ raise RuntimeError (
265
+ "Executorch program data already exists in the ETRecord. "
266
+ "Cannot add executorch program data when it already exists."
267
+ )
268
+
269
+ # Process executorch program and extract data
270
+ debug_handle_map , delegate_map , reference_outputs , representative_inputs = (
271
+ _process_executorch_program (executorch_program )
272
+ )
273
+
274
+ # Set the extracted data
275
+ self ._debug_handle_map = debug_handle_map
276
+ self ._delegate_map = delegate_map
277
+ self ._reference_outputs = reference_outputs
278
+ self ._representative_inputs = representative_inputs
279
+
280
+ def add_exported_program (
281
+ self ,
282
+ exported_program : Optional [Union [ExportedProgram , Dict [str , ExportedProgram ]]],
283
+ ) -> None :
284
+ """
285
+ Add exported program to the ETRecord after it has been created.
286
+
287
+ This method allows users to add an exported program they want to record
288
+ to an existing ETRecord instance. The exported program will be included
289
+ when the ETRecord is saved.
290
+
291
+ Args:
292
+ exported_program: The exported program for this model returned by the call to
293
+ `torch.export()` or a dictionary with method names as keys and exported programs as values.
294
+ Can be None, in which case no exported program data will be added.
295
+
296
+ Raises:
297
+ RuntimeError: If exported program already exists in the ETRecord.
298
+ """
299
+ # Check if exported program already exists
300
+ if self .exported_program is not None or self .export_graph_id is not None :
301
+ raise RuntimeError (
302
+ "Exported program already exists in the ETRecord. "
303
+ "Cannot add exported program when it already exists."
304
+ )
305
+
306
+ # Process exported program and extract data
307
+ processed_exported_program , export_graph_id = _process_exported_program (
308
+ exported_program
309
+ )
310
+
311
+ # Set the extracted data
312
+ self .exported_program = processed_exported_program
313
+ self .export_graph_id = export_graph_id
314
+
315
+ def add_edge_dialect_program (
316
+ self ,
317
+ edge_dialect_program : Union [EdgeProgramManager , ExirExportedProgram ],
318
+ ) -> None :
319
+ """
320
+ Add edge dialect program to the ETRecord after it has been created.
321
+
322
+ This method allows users to add an edge dialect program they want to record
323
+ to an existing ETRecord instance. The edge dialect program will be included
324
+ when the ETRecord is saved.
325
+
326
+ Args:
327
+ edge_dialect_program: The edge dialect program for this model returned by the call to
328
+ `to_edge()` or `EdgeProgramManager` for this model.
329
+
330
+ Raises:
331
+ RuntimeError: If edge dialect program already exists in the ETRecord.
332
+ """
333
+ # Check if edge dialect program already exists
334
+ if self .edge_dialect_program is not None :
335
+ raise RuntimeError (
336
+ "Edge dialect program already exists in the ETRecord. "
337
+ "Cannot add edge dialect program when it already exists."
338
+ )
339
+
340
+ # Process edge dialect program and extract data
341
+ processed_edge_dialect_program = _process_edge_dialect_program (
342
+ edge_dialect_program
343
+ )
344
+
345
+ # Set the extracted data
346
+ self .edge_dialect_program = processed_edge_dialect_program
347
+
203
348
204
349
def _get_reference_outputs (
205
350
bundled_program : BundledProgram ,
@@ -285,37 +430,24 @@ def generate_etrecord(
285
430
Returns:
286
431
None
287
432
"""
288
- # Process all inputs and prepare data for ETRecord construction
289
- processed_exported_program , export_graph_id = _process_exported_program (
290
- exported_program
291
- )
292
- graph_map = _process_extra_recorded_modules (extra_recorded_export_modules )
293
- processed_edge_dialect_program = _process_edge_dialect_program (edge_dialect_program )
294
- debug_handle_map , delegate_map , reference_outputs , representative_inputs = (
295
- _process_executorch_program (executorch_program )
296
- )
433
+ etrecord = ETRecord ()
434
+ etrecord .add_exported_program (exported_program )
435
+ etrecord .add_edge_dialect_program (edge_dialect_program )
436
+ etrecord .add_executorch_program (executorch_program )
297
437
298
- # Create ETRecord instance and save
299
- etrecord = ETRecord (
300
- exported_program = processed_exported_program ,
301
- export_graph_id = export_graph_id ,
302
- edge_dialect_program = processed_edge_dialect_program ,
303
- graph_map = graph_map if graph_map else None ,
304
- _debug_handle_map = debug_handle_map ,
305
- _delegate_map = delegate_map ,
306
- _reference_outputs = reference_outputs ,
307
- _representative_inputs = representative_inputs ,
308
- )
438
+ # Add extra export modules if user provided
439
+ if extra_recorded_export_modules is not None :
440
+ etrecord .add_extra_export_modules (extra_recorded_export_modules )
309
441
310
442
etrecord .save (et_record )
311
443
312
444
313
445
def _process_exported_program (
314
446
exported_program : Optional [Union [ExportedProgram , Dict [str , ExportedProgram ]]]
315
- ) -> tuple [Optional [ExportedProgram ], int ]:
447
+ ) -> tuple [Optional [ExportedProgram ], Optional [ int ] ]:
316
448
"""Process exported program and return the processed program and export graph id."""
317
449
processed_exported_program = None
318
- export_graph_id = 0
450
+ export_graph_id = None
319
451
320
452
if exported_program is not None :
321
453
if isinstance (exported_program , dict ) and "forward" in exported_program :
@@ -329,29 +461,6 @@ def _process_exported_program(
329
461
return processed_exported_program , export_graph_id
330
462
331
463
332
- def _process_extra_recorded_modules (
333
- extra_recorded_export_modules : Optional [
334
- Dict [
335
- str ,
336
- Union [
337
- ExportedProgram ,
338
- ExirExportedProgram ,
339
- EdgeProgramManager ,
340
- ],
341
- ]
342
- ]
343
- ) -> Dict [str , ExportedProgram ]:
344
- """Process extra recorded export modules and return graph map."""
345
- graph_map = {}
346
-
347
- if extra_recorded_export_modules is not None :
348
- for module_name , export_module in extra_recorded_export_modules .items ():
349
- _validate_module_name (module_name )
350
- _add_module_to_graph_map (graph_map , module_name , export_module )
351
-
352
- return graph_map
353
-
354
-
355
464
def _validate_module_name (module_name : str ) -> None :
356
465
"""Validate that module name is not a reserved name."""
357
466
contains_reserved_name = any (
@@ -369,6 +478,8 @@ def _add_module_to_graph_map(
369
478
export_module : Union [ExportedProgram , ExirExportedProgram , EdgeProgramManager ],
370
479
) -> None :
371
480
"""Add export module to graph map based on its type."""
481
+ _validate_module_name (module_name )
482
+
372
483
if isinstance (export_module , ExirExportedProgram ):
373
484
graph_map [f"{ module_name } /forward" ] = export_module .exported_program
374
485
elif isinstance (export_module , ExportedProgram ):
0 commit comments