Skip to content

Commit f4fe072

Browse files
committed
add executorch program equipment support in etrecord class
Pull Request resolved: #13020 Previously we have to provide all essentail infos at the same time to generate etrecord; however if we want to generate it through export flow we can not find a stage that having all essential infos so that we need to have a new way to contruct it on-the-fly. This diff makes the target happen by adding three functions: `add_exported_program`, `add_edge_dialect_program` and `add_executorch_program` so that whenever we have the required info we can equip it into etrecord. Also update test case for test coverage. ghstack-source-id: 300161202 Differential Revision: [D79294945](https://our.internmc.facebook.com/intern/diff/D79294945/)
1 parent 2202e7f commit f4fe072

File tree

2 files changed

+715
-46
lines changed

2 files changed

+715
-46
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 125 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,122 @@ def add_extra_export_modules(
229229
# Now self.graph_map is guaranteed to be non-None
230230
graph_map = self.graph_map
231231
for module_name, export_module in extra_recorded_export_modules.items():
232-
_validate_module_name(module_name)
233232
_add_module_to_graph_map(graph_map, module_name, export_module)
234233

234+
def add_executorch_program(
235+
self,
236+
executorch_program: Union[
237+
ExecutorchProgram,
238+
ExecutorchProgramManager,
239+
BundledProgram,
240+
],
241+
) -> None:
242+
"""
243+
Add executorch program data to the ETRecord after it has been created.
244+
245+
This method allows users to add executorch program data they want to record
246+
to an existing ETRecord instance. The executorch program data includes debug handle map,
247+
delegate map, reference outputs, and representative inputs that will be included
248+
when the ETRecord is saved.
249+
250+
Args:
251+
executorch_program: The ExecuTorch program for this model returned by the call to
252+
`to_executorch()` or the `BundledProgram` of this model.
253+
254+
Raises:
255+
RuntimeError: If executorch program data already exists in the ETRecord.
256+
"""
257+
# Check if executorch program data already exists
258+
if (
259+
self._debug_handle_map is not None
260+
or self._delegate_map is not None
261+
or self._reference_outputs is not None
262+
or self._representative_inputs is not None
263+
):
264+
raise RuntimeError(
265+
"Executorch program data already exists in the ETRecord. "
266+
"Cannot add executorch program data when it already exists."
267+
)
268+
269+
# Process executorch program and extract data
270+
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
271+
_process_executorch_program(executorch_program)
272+
)
273+
274+
# Set the extracted data
275+
self._debug_handle_map = debug_handle_map
276+
self._delegate_map = delegate_map
277+
self._reference_outputs = reference_outputs
278+
self._representative_inputs = representative_inputs
279+
280+
def add_exported_program(
281+
self,
282+
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]],
283+
) -> None:
284+
"""
285+
Add exported program to the ETRecord after it has been created.
286+
287+
This method allows users to add an exported program they want to record
288+
to an existing ETRecord instance. The exported program will be included
289+
when the ETRecord is saved.
290+
291+
Args:
292+
exported_program: The exported program for this model returned by the call to
293+
`torch.export()` or a dictionary with method names as keys and exported programs as values.
294+
Can be None, in which case no exported program data will be added.
295+
296+
Raises:
297+
RuntimeError: If exported program already exists in the ETRecord.
298+
"""
299+
# Check if exported program already exists
300+
if self.exported_program is not None or self.export_graph_id is not None:
301+
raise RuntimeError(
302+
"Exported program already exists in the ETRecord. "
303+
"Cannot add exported program when it already exists."
304+
)
305+
306+
# Process exported program and extract data
307+
processed_exported_program, export_graph_id = _process_exported_program(
308+
exported_program
309+
)
310+
311+
# Set the extracted data
312+
self.exported_program = processed_exported_program
313+
self.export_graph_id = export_graph_id
314+
315+
def add_edge_dialect_program(
316+
self,
317+
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
318+
) -> None:
319+
"""
320+
Add edge dialect program to the ETRecord after it has been created.
321+
322+
This method allows users to add an edge dialect program they want to record
323+
to an existing ETRecord instance. The edge dialect program will be included
324+
when the ETRecord is saved.
325+
326+
Args:
327+
edge_dialect_program: The edge dialect program for this model returned by the call to
328+
`to_edge()` or `EdgeProgramManager` for this model.
329+
330+
Raises:
331+
RuntimeError: If edge dialect program already exists in the ETRecord.
332+
"""
333+
# Check if edge dialect program already exists
334+
if self.edge_dialect_program is not None:
335+
raise RuntimeError(
336+
"Edge dialect program already exists in the ETRecord. "
337+
"Cannot add edge dialect program when it already exists."
338+
)
339+
340+
# Process edge dialect program and extract data
341+
processed_edge_dialect_program = _process_edge_dialect_program(
342+
edge_dialect_program
343+
)
344+
345+
# Set the extracted data
346+
self.edge_dialect_program = processed_edge_dialect_program
347+
235348

236349
def _get_reference_outputs(
237350
bundled_program: BundledProgram,
@@ -317,37 +430,24 @@ def generate_etrecord(
317430
Returns:
318431
None
319432
"""
320-
# Process all inputs and prepare data for ETRecord construction
321-
processed_exported_program, export_graph_id = _process_exported_program(
322-
exported_program
323-
)
324-
graph_map = _process_extra_recorded_modules(extra_recorded_export_modules)
325-
processed_edge_dialect_program = _process_edge_dialect_program(edge_dialect_program)
326-
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
327-
_process_executorch_program(executorch_program)
328-
)
433+
etrecord = ETRecord()
434+
etrecord.add_exported_program(exported_program)
435+
etrecord.add_edge_dialect_program(edge_dialect_program)
436+
etrecord.add_executorch_program(executorch_program)
329437

330-
# Create ETRecord instance and save
331-
etrecord = ETRecord(
332-
exported_program=processed_exported_program,
333-
export_graph_id=export_graph_id,
334-
edge_dialect_program=processed_edge_dialect_program,
335-
graph_map=graph_map if graph_map else None,
336-
_debug_handle_map=debug_handle_map,
337-
_delegate_map=delegate_map,
338-
_reference_outputs=reference_outputs,
339-
_representative_inputs=representative_inputs,
340-
)
438+
# Add extra export modules if user provided
439+
if extra_recorded_export_modules is not None:
440+
etrecord.add_extra_export_modules(extra_recorded_export_modules)
341441

342442
etrecord.save(et_record)
343443

344444

345445
def _process_exported_program(
346446
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]]
347-
) -> tuple[Optional[ExportedProgram], int]:
447+
) -> tuple[Optional[ExportedProgram], Optional[int]]:
348448
"""Process exported program and return the processed program and export graph id."""
349449
processed_exported_program = None
350-
export_graph_id = 0
450+
export_graph_id = None
351451

352452
if exported_program is not None:
353453
if isinstance(exported_program, dict) and "forward" in exported_program:
@@ -361,29 +461,6 @@ def _process_exported_program(
361461
return processed_exported_program, export_graph_id
362462

363463

364-
def _process_extra_recorded_modules(
365-
extra_recorded_export_modules: Optional[
366-
Dict[
367-
str,
368-
Union[
369-
ExportedProgram,
370-
ExirExportedProgram,
371-
EdgeProgramManager,
372-
],
373-
]
374-
]
375-
) -> Dict[str, ExportedProgram]:
376-
"""Process extra recorded export modules and return graph map."""
377-
graph_map = {}
378-
379-
if extra_recorded_export_modules is not None:
380-
for module_name, export_module in extra_recorded_export_modules.items():
381-
_validate_module_name(module_name)
382-
_add_module_to_graph_map(graph_map, module_name, export_module)
383-
384-
return graph_map
385-
386-
387464
def _validate_module_name(module_name: str) -> None:
388465
"""Validate that module name is not a reserved name."""
389466
contains_reserved_name = any(
@@ -401,6 +478,8 @@ def _add_module_to_graph_map(
401478
export_module: Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager],
402479
) -> None:
403480
"""Add export module to graph map based on its type."""
481+
_validate_module_name(module_name)
482+
404483
if isinstance(export_module, ExirExportedProgram):
405484
graph_map[f"{module_name}/forward"] = export_module.exported_program
406485
elif isinstance(export_module, ExportedProgram):

0 commit comments

Comments
 (0)