@@ -229,9 +229,122 @@ def add_extra_export_modules(
229229 # Now self.graph_map is guaranteed to be non-None
230230 graph_map = self .graph_map
231231 for module_name , export_module in extra_recorded_export_modules .items ():
232- _validate_module_name (module_name )
233232 _add_module_to_graph_map (graph_map , module_name , export_module )
234233
234+ def add_executorch_program (
235+ self ,
236+ executorch_program : Union [
237+ ExecutorchProgram ,
238+ ExecutorchProgramManager ,
239+ BundledProgram ,
240+ ],
241+ ) -> None :
242+ """
243+ Add executorch program data to the ETRecord after it has been created.
244+
245+ This method allows users to add executorch program data they want to record
246+ to an existing ETRecord instance. The executorch program data includes debug handle map,
247+ delegate map, reference outputs, and representative inputs that will be included
248+ when the ETRecord is saved.
249+
250+ Args:
251+ executorch_program: The ExecuTorch program for this model returned by the call to
252+ `to_executorch()` or the `BundledProgram` of this model.
253+
254+ Raises:
255+ RuntimeError: If executorch program data already exists in the ETRecord.
256+ """
257+ # Check if executorch program data already exists
258+ if (
259+ self ._debug_handle_map is not None
260+ or self ._delegate_map is not None
261+ or self ._reference_outputs is not None
262+ or self ._representative_inputs is not None
263+ ):
264+ raise RuntimeError (
265+ "Executorch program data already exists in the ETRecord. "
266+ "Cannot add executorch program data when it already exists."
267+ )
268+
269+ # Process executorch program and extract data
270+ debug_handle_map , delegate_map , reference_outputs , representative_inputs = (
271+ _process_executorch_program (executorch_program )
272+ )
273+
274+ # Set the extracted data
275+ self ._debug_handle_map = debug_handle_map
276+ self ._delegate_map = delegate_map
277+ self ._reference_outputs = reference_outputs
278+ self ._representative_inputs = representative_inputs
279+
280+ def add_exported_program (
281+ self ,
282+ exported_program : Optional [Union [ExportedProgram , Dict [str , ExportedProgram ]]],
283+ ) -> None :
284+ """
285+ Add exported program to the ETRecord after it has been created.
286+
287+ This method allows users to add an exported program they want to record
288+ to an existing ETRecord instance. The exported program will be included
289+ when the ETRecord is saved.
290+
291+ Args:
292+ exported_program: The exported program for this model returned by the call to
293+ `torch.export()` or a dictionary with method names as keys and exported programs as values.
294+ Can be None, in which case no exported program data will be added.
295+
296+ Raises:
297+ RuntimeError: If exported program already exists in the ETRecord.
298+ """
299+ # Check if exported program already exists
300+ if self .exported_program is not None or self .export_graph_id is not None :
301+ raise RuntimeError (
302+ "Exported program already exists in the ETRecord. "
303+ "Cannot add exported program when it already exists."
304+ )
305+
306+ # Process exported program and extract data
307+ processed_exported_program , export_graph_id = _process_exported_program (
308+ exported_program
309+ )
310+
311+ # Set the extracted data
312+ self .exported_program = processed_exported_program
313+ self .export_graph_id = export_graph_id
314+
315+ def add_edge_dialect_program (
316+ self ,
317+ edge_dialect_program : Union [EdgeProgramManager , ExirExportedProgram ],
318+ ) -> None :
319+ """
320+ Add edge dialect program to the ETRecord after it has been created.
321+
322+ This method allows users to add an edge dialect program they want to record
323+ to an existing ETRecord instance. The edge dialect program will be included
324+ when the ETRecord is saved.
325+
326+ Args:
327+ edge_dialect_program: The edge dialect program for this model returned by the call to
328+ `to_edge()` or `EdgeProgramManager` for this model.
329+
330+ Raises:
331+ RuntimeError: If edge dialect program already exists in the ETRecord.
332+ """
333+ # Check if edge dialect program already exists
334+ if self .edge_dialect_program is not None :
335+ raise RuntimeError (
336+ "Edge dialect program already exists in the ETRecord. "
337+ "Cannot add edge dialect program when it already exists."
338+ )
339+
340+ # Process edge dialect program and extract data
341+ processed_edge_dialect_program = _process_edge_dialect_program (
342+ edge_dialect_program
343+ )
344+
345+ # Set the extracted data
346+ self .edge_dialect_program = processed_edge_dialect_program
347+
235348
236349def _get_reference_outputs (
237350 bundled_program : BundledProgram ,
@@ -317,37 +430,24 @@ def generate_etrecord(
317430 Returns:
318431 None
319432 """
320- # Process all inputs and prepare data for ETRecord construction
321- processed_exported_program , export_graph_id = _process_exported_program (
322- exported_program
323- )
324- graph_map = _process_extra_recorded_modules (extra_recorded_export_modules )
325- processed_edge_dialect_program = _process_edge_dialect_program (edge_dialect_program )
326- debug_handle_map , delegate_map , reference_outputs , representative_inputs = (
327- _process_executorch_program (executorch_program )
328- )
433+ etrecord = ETRecord ()
434+ etrecord .add_exported_program (exported_program )
435+ etrecord .add_edge_dialect_program (edge_dialect_program )
436+ etrecord .add_executorch_program (executorch_program )
329437
330- # Create ETRecord instance and save
331- etrecord = ETRecord (
332- exported_program = processed_exported_program ,
333- export_graph_id = export_graph_id ,
334- edge_dialect_program = processed_edge_dialect_program ,
335- graph_map = graph_map if graph_map else None ,
336- _debug_handle_map = debug_handle_map ,
337- _delegate_map = delegate_map ,
338- _reference_outputs = reference_outputs ,
339- _representative_inputs = representative_inputs ,
340- )
438+ # Add extra export modules if user provided
439+ if extra_recorded_export_modules is not None :
440+ etrecord .add_extra_export_modules (extra_recorded_export_modules )
341441
342442 etrecord .save (et_record )
343443
344444
345445def _process_exported_program (
346446 exported_program : Optional [Union [ExportedProgram , Dict [str , ExportedProgram ]]]
347- ) -> tuple [Optional [ExportedProgram ], int ]:
447+ ) -> tuple [Optional [ExportedProgram ], Optional [ int ] ]:
348448 """Process exported program and return the processed program and export graph id."""
349449 processed_exported_program = None
350- export_graph_id = 0
450+ export_graph_id = None
351451
352452 if exported_program is not None :
353453 if isinstance (exported_program , dict ) and "forward" in exported_program :
@@ -361,29 +461,6 @@ def _process_exported_program(
361461 return processed_exported_program , export_graph_id
362462
363463
364- def _process_extra_recorded_modules (
365- extra_recorded_export_modules : Optional [
366- Dict [
367- str ,
368- Union [
369- ExportedProgram ,
370- ExirExportedProgram ,
371- EdgeProgramManager ,
372- ],
373- ]
374- ]
375- ) -> Dict [str , ExportedProgram ]:
376- """Process extra recorded export modules and return graph map."""
377- graph_map = {}
378-
379- if extra_recorded_export_modules is not None :
380- for module_name , export_module in extra_recorded_export_modules .items ():
381- _validate_module_name (module_name )
382- _add_module_to_graph_map (graph_map , module_name , export_module )
383-
384- return graph_map
385-
386-
387464def _validate_module_name (module_name : str ) -> None :
388465 """Validate that module name is not a reserved name."""
389466 contains_reserved_name = any (
@@ -401,6 +478,8 @@ def _add_module_to_graph_map(
401478 export_module : Union [ExportedProgram , ExirExportedProgram , EdgeProgramManager ],
402479) -> None :
403480 """Add export module to graph map based on its type."""
481+ _validate_module_name (module_name )
482+
404483 if isinstance (export_module , ExirExportedProgram ):
405484 graph_map [f"{ module_name } /forward" ] = export_module .exported_program
406485 elif isinstance (export_module , ExportedProgram ):
0 commit comments