@@ -200,6 +200,151 @@ def _save_edge_dialect_program(
200200 f"{ base_name } _example_inputs" , serialized_artifact .example_inputs
201201 )
202202
203+ def add_extra_export_modules (
204+ self ,
205+ extra_recorded_export_modules : Dict [
206+ str ,
207+ Union [
208+ ExportedProgram ,
209+ ExirExportedProgram ,
210+ EdgeProgramManager ,
211+ ],
212+ ],
213+ ) -> None :
214+ """
215+ Add extra export modules to the ETRecord after it has been created.
216+
217+ This method allows users to add more export modules they want to record
218+ to an existing ETRecord instance. The modules will be added to the graph_map
219+ and will be included when the ETRecord is saved.
220+
221+ Args:
222+ extra_recorded_export_modules: A dictionary of graph modules with the key being
223+ the user provided name and the value being the corresponding exported module.
224+ The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`.
225+ """
226+ if self .graph_map is None :
227+ self .graph_map = {}
228+
229+ # Now self.graph_map is guaranteed to be non-None
230+ graph_map = self .graph_map
231+ for module_name , export_module in extra_recorded_export_modules .items ():
232+ _add_module_to_graph_map (graph_map , module_name , export_module )
233+
234+ def add_executorch_program (
235+ self ,
236+ executorch_program : Union [
237+ ExecutorchProgram ,
238+ ExecutorchProgramManager ,
239+ BundledProgram ,
240+ ],
241+ ) -> None :
242+ """
243+ Add executorch program data to the ETRecord after it has been created.
244+
245+ This method allows users to add executorch program data they want to record
246+ to an existing ETRecord instance. The executorch program data includes debug handle map,
247+ delegate map, reference outputs, and representative inputs that will be included
248+ when the ETRecord is saved.
249+
250+ Args:
251+ executorch_program: The ExecuTorch program for this model returned by the call to
252+ `to_executorch()` or the `BundledProgram` of this model.
253+
254+ Raises:
255+ RuntimeError: If executorch program data already exists in the ETRecord.
256+ """
257+ # Check if executorch program data already exists
258+ if (
259+ self ._debug_handle_map is not None
260+ or self ._delegate_map is not None
261+ or self ._reference_outputs is not None
262+ or self ._representative_inputs is not None
263+ ):
264+ raise RuntimeError (
265+ "Executorch program data already exists in the ETRecord. "
266+ "Cannot add executorch program data when it already exists."
267+ )
268+
269+ # Process executorch program and extract data
270+ debug_handle_map , delegate_map , reference_outputs , representative_inputs = (
271+ _process_executorch_program (executorch_program )
272+ )
273+
274+ # Set the extracted data
275+ self ._debug_handle_map = debug_handle_map
276+ self ._delegate_map = delegate_map
277+ self ._reference_outputs = reference_outputs
278+ self ._representative_inputs = representative_inputs
279+
280+ def add_exported_program (
281+ self ,
282+ exported_program : Optional [Union [ExportedProgram , Dict [str , ExportedProgram ]]],
283+ ) -> None :
284+ """
285+ Add exported program to the ETRecord after it has been created.
286+
287+ This method allows users to add an exported program they want to record
288+ to an existing ETRecord instance. The exported program will be included
289+ when the ETRecord is saved.
290+
291+ Args:
292+ exported_program: The exported program for this model returned by the call to
293+ `torch.export()` or a dictionary with method names as keys and exported programs as values.
294+ Can be None, in which case no exported program data will be added.
295+
296+ Raises:
297+ RuntimeError: If exported program already exists in the ETRecord.
298+ """
299+ # Check if exported program already exists
300+ if self .exported_program is not None or self .export_graph_id is not None :
301+ raise RuntimeError (
302+ "Exported program already exists in the ETRecord. "
303+ "Cannot add exported program when it already exists."
304+ )
305+
306+ # Process exported program and extract data
307+ processed_exported_program , export_graph_id = _process_exported_program (
308+ exported_program
309+ )
310+
311+ # Set the extracted data
312+ self .exported_program = processed_exported_program
313+ self .export_graph_id = export_graph_id
314+
315+ def add_edge_dialect_program (
316+ self ,
317+ edge_dialect_program : Union [EdgeProgramManager , ExirExportedProgram ],
318+ ) -> None :
319+ """
320+ Add edge dialect program to the ETRecord after it has been created.
321+
322+ This method allows users to add an edge dialect program they want to record
323+ to an existing ETRecord instance. The edge dialect program will be included
324+ when the ETRecord is saved.
325+
326+ Args:
327+ edge_dialect_program: The edge dialect program for this model returned by the call to
328+ `to_edge()` or `EdgeProgramManager` for this model.
329+
330+ Raises:
331+ RuntimeError: If edge dialect program already exists in the ETRecord.
332+ """
333+ # Check if edge dialect program already exists
334+ if self .edge_dialect_program is not None :
335+ raise RuntimeError (
336+ "Edge dialect program already exists in the ETRecord. "
337+ "Cannot add edge dialect program when it already exists."
338+ )
339+
340+ # Process edge dialect program and extract data
341+ processed_edge_dialect_program = _process_edge_dialect_program (
342+ edge_dialect_program
343+ )
344+
345+ # Set the extracted data
346+ self .edge_dialect_program = processed_edge_dialect_program
347+
203348
204349def _get_reference_outputs (
205350 bundled_program : BundledProgram ,
@@ -285,37 +430,24 @@ def generate_etrecord(
285430 Returns:
286431 None
287432 """
288- # Process all inputs and prepare data for ETRecord construction
289- processed_exported_program , export_graph_id = _process_exported_program (
290- exported_program
291- )
292- graph_map = _process_extra_recorded_modules (extra_recorded_export_modules )
293- processed_edge_dialect_program = _process_edge_dialect_program (edge_dialect_program )
294- debug_handle_map , delegate_map , reference_outputs , representative_inputs = (
295- _process_executorch_program (executorch_program )
296- )
433+ etrecord = ETRecord ()
434+ etrecord .add_exported_program (exported_program )
435+ etrecord .add_edge_dialect_program (edge_dialect_program )
436+ etrecord .add_executorch_program (executorch_program )
297437
298- # Create ETRecord instance and save
299- etrecord = ETRecord (
300- exported_program = processed_exported_program ,
301- export_graph_id = export_graph_id ,
302- edge_dialect_program = processed_edge_dialect_program ,
303- graph_map = graph_map if graph_map else None ,
304- _debug_handle_map = debug_handle_map ,
305- _delegate_map = delegate_map ,
306- _reference_outputs = reference_outputs ,
307- _representative_inputs = representative_inputs ,
308- )
438+ # Add extra export modules if user provided
439+ if extra_recorded_export_modules is not None :
440+ etrecord .add_extra_export_modules (extra_recorded_export_modules )
309441
310442 etrecord .save (et_record )
311443
312444
313445def _process_exported_program (
314446 exported_program : Optional [Union [ExportedProgram , Dict [str , ExportedProgram ]]]
315- ) -> tuple [Optional [ExportedProgram ], int ]:
447+ ) -> tuple [Optional [ExportedProgram ], Optional [ int ] ]:
316448 """Process exported program and return the processed program and export graph id."""
317449 processed_exported_program = None
318- export_graph_id = 0
450+ export_graph_id = None
319451
320452 if exported_program is not None :
321453 if isinstance (exported_program , dict ) and "forward" in exported_program :
@@ -329,29 +461,6 @@ def _process_exported_program(
329461 return processed_exported_program , export_graph_id
330462
331463
332- def _process_extra_recorded_modules (
333- extra_recorded_export_modules : Optional [
334- Dict [
335- str ,
336- Union [
337- ExportedProgram ,
338- ExirExportedProgram ,
339- EdgeProgramManager ,
340- ],
341- ]
342- ]
343- ) -> Dict [str , ExportedProgram ]:
344- """Process extra recorded export modules and return graph map."""
345- graph_map = {}
346-
347- if extra_recorded_export_modules is not None :
348- for module_name , export_module in extra_recorded_export_modules .items ():
349- _validate_module_name (module_name )
350- _add_module_to_graph_map (graph_map , module_name , export_module )
351-
352- return graph_map
353-
354-
355464def _validate_module_name (module_name : str ) -> None :
356465 """Validate that module name is not a reserved name."""
357466 contains_reserved_name = any (
@@ -369,6 +478,8 @@ def _add_module_to_graph_map(
369478 export_module : Union [ExportedProgram , ExirExportedProgram , EdgeProgramManager ],
370479) -> None :
371480 """Add export module to graph map based on its type."""
481+ _validate_module_name (module_name )
482+
372483 if isinstance (export_module , ExirExportedProgram ):
373484 graph_map [f"{ module_name } /forward" ] = export_module .exported_program
374485 elif isinstance (export_module , ExportedProgram ):
0 commit comments