Skip to content

Commit b4fc97c

Browse files
committed
add executorch program equipment support in etrecord class
Previously we have to provide all essentail infos at the same time to generate etrecord; however if we want to generate it through export flow we can not find a stage that having all essential infos so that we need to have a new way to contruct it on-the-fly. This diff makes the target happen by adding three functions: `add_exported_program`, `add_edge_dialect_program` and `add_executorch_program` so that whenever we have the required info we can equip it into etrecord. Also update test case for test coverage. Differential Revision: [D79294945](https://our.internmc.facebook.com/intern/diff/D79294945/) [ghstack-poisoned]
1 parent a8114d3 commit b4fc97c

File tree

2 files changed

+761
-46
lines changed

2 files changed

+761
-46
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 125 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,122 @@ def add_extra_export_modules(
229229
# Now self.graph_map is guaranteed to be non-None
230230
graph_map = self.graph_map
231231
for module_name, export_module in extra_recorded_export_modules.items():
232-
_validate_module_name(module_name)
233232
_add_module_to_graph_map(graph_map, module_name, export_module)
234233

234+
def add_executorch_program(
235+
self,
236+
executorch_program: Union[
237+
ExecutorchProgram,
238+
ExecutorchProgramManager,
239+
BundledProgram,
240+
],
241+
) -> None:
242+
"""
243+
Add executorch program data to the ETRecord after it has been created.
244+
245+
This method allows users to add executorch program data they want to record
246+
to an existing ETRecord instance. The executorch program data includes debug handle map,
247+
delegate map, reference outputs, and representative inputs that will be included
248+
when the ETRecord is saved.
249+
250+
Args:
251+
executorch_program: The ExecuTorch program for this model returned by the call to
252+
`to_executorch()` or the `BundledProgram` of this model.
253+
254+
Raises:
255+
RuntimeError: If executorch program data already exists in the ETRecord.
256+
"""
257+
# Check if executorch program data already exists
258+
if (
259+
self._debug_handle_map is not None
260+
or self._delegate_map is not None
261+
or self._reference_outputs is not None
262+
or self._representative_inputs is not None
263+
):
264+
raise RuntimeError(
265+
"Executorch program data already exists in the ETRecord. "
266+
"Cannot add executorch program data when it already exists."
267+
)
268+
269+
# Process executorch program and extract data
270+
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
271+
_process_executorch_program(executorch_program)
272+
)
273+
274+
# Set the extracted data
275+
self._debug_handle_map = debug_handle_map
276+
self._delegate_map = delegate_map
277+
self._reference_outputs = reference_outputs
278+
self._representative_inputs = representative_inputs
279+
280+
def add_exported_program(
281+
self,
282+
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]],
283+
) -> None:
284+
"""
285+
Add exported program to the ETRecord after it has been created.
286+
287+
This method allows users to add an exported program they want to record
288+
to an existing ETRecord instance. The exported program will be included
289+
when the ETRecord is saved.
290+
291+
Args:
292+
exported_program: The exported program for this model returned by the call to
293+
`torch.export()` or a dictionary with method names as keys and exported programs as values.
294+
Can be None, in which case no exported program data will be added.
295+
296+
Raises:
297+
RuntimeError: If exported program already exists in the ETRecord.
298+
"""
299+
# Check if exported program already exists
300+
if self.exported_program is not None or self.export_graph_id is not None:
301+
raise RuntimeError(
302+
"Exported program already exists in the ETRecord. "
303+
"Cannot add exported program when it already exists."
304+
)
305+
306+
# Process exported program and extract data
307+
processed_exported_program, export_graph_id = _process_exported_program(
308+
exported_program
309+
)
310+
311+
# Set the extracted data
312+
self.exported_program = processed_exported_program
313+
self.export_graph_id = export_graph_id
314+
315+
def add_edge_dialect_program(
316+
self,
317+
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
318+
) -> None:
319+
"""
320+
Add edge dialect program to the ETRecord after it has been created.
321+
322+
This method allows users to add an edge dialect program they want to record
323+
to an existing ETRecord instance. The edge dialect program will be included
324+
when the ETRecord is saved.
325+
326+
Args:
327+
edge_dialect_program: The edge dialect program for this model returned by the call to
328+
`to_edge()` or `EdgeProgramManager` for this model.
329+
330+
Raises:
331+
RuntimeError: If edge dialect program already exists in the ETRecord.
332+
"""
333+
# Check if edge dialect program already exists
334+
if self.edge_dialect_program is not None:
335+
raise RuntimeError(
336+
"Edge dialect program already exists in the ETRecord. "
337+
"Cannot add edge dialect program when it already exists."
338+
)
339+
340+
# Process edge dialect program and extract data
341+
processed_edge_dialect_program = _process_edge_dialect_program(
342+
edge_dialect_program
343+
)
344+
345+
# Set the extracted data
346+
self.edge_dialect_program = processed_edge_dialect_program
347+
235348

236349
def _get_reference_outputs(
237350
bundled_program: BundledProgram,
@@ -317,37 +430,24 @@ def generate_etrecord(
317430
Returns:
318431
None
319432
"""
320-
# Process all inputs and prepare data for ETRecord construction
321-
processed_exported_program, export_graph_id = _process_exported_program(
322-
exported_program
323-
)
324-
graph_map = _process_extra_recorded_modules(extra_recorded_export_modules)
325-
processed_edge_dialect_program = _process_edge_dialect_program(edge_dialect_program)
326-
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
327-
_process_executorch_program(executorch_program)
328-
)
433+
etrecord = ETRecord()
434+
etrecord.add_exported_program(exported_program)
435+
etrecord.add_edge_dialect_program(edge_dialect_program)
436+
etrecord.add_executorch_program(executorch_program)
329437

330-
# Create ETRecord instance and save
331-
etrecord = ETRecord(
332-
exported_program=processed_exported_program,
333-
export_graph_id=export_graph_id,
334-
edge_dialect_program=processed_edge_dialect_program,
335-
graph_map=graph_map if graph_map else None,
336-
_debug_handle_map=debug_handle_map,
337-
_delegate_map=delegate_map,
338-
_reference_outputs=reference_outputs,
339-
_representative_inputs=representative_inputs,
340-
)
438+
# Add extra export modules if user provided
439+
if extra_recorded_export_modules is not None:
440+
etrecord.add_extra_export_modules(extra_recorded_export_modules)
341441

342442
etrecord.save(et_record)
343443

344444

345445
def _process_exported_program(
346446
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]]
347-
) -> tuple[Optional[ExportedProgram], int]:
447+
) -> tuple[Optional[ExportedProgram], Optional[int]]:
348448
"""Process exported program and return the processed program and export graph id."""
349449
processed_exported_program = None
350-
export_graph_id = 0
450+
export_graph_id = None
351451

352452
if exported_program is not None:
353453
if isinstance(exported_program, dict) and "forward" in exported_program:
@@ -361,29 +461,6 @@ def _process_exported_program(
361461
return processed_exported_program, export_graph_id
362462

363463

364-
def _process_extra_recorded_modules(
365-
extra_recorded_export_modules: Optional[
366-
Dict[
367-
str,
368-
Union[
369-
ExportedProgram,
370-
ExirExportedProgram,
371-
EdgeProgramManager,
372-
],
373-
]
374-
]
375-
) -> Dict[str, ExportedProgram]:
376-
"""Process extra recorded export modules and return graph map."""
377-
graph_map = {}
378-
379-
if extra_recorded_export_modules is not None:
380-
for module_name, export_module in extra_recorded_export_modules.items():
381-
_validate_module_name(module_name)
382-
_add_module_to_graph_map(graph_map, module_name, export_module)
383-
384-
return graph_map
385-
386-
387464
def _validate_module_name(module_name: str) -> None:
388465
"""Validate that module name is not a reserved name."""
389466
contains_reserved_name = any(
@@ -401,6 +478,8 @@ def _add_module_to_graph_map(
401478
export_module: Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager],
402479
) -> None:
403480
"""Add export module to graph map based on its type."""
481+
_validate_module_name(module_name)
482+
404483
if isinstance(export_module, ExirExportedProgram):
405484
graph_map[f"{module_name}/forward"] = export_module.exported_program
406485
elif isinstance(export_module, ExportedProgram):

0 commit comments

Comments
 (0)