Skip to content

Commit 964a57a

Browse files
authored
bring etrecord updated "reverted" by gh patch fix bot back
Differential Revision: D79599520 Pull Request resolved: #13117
1 parent b92c5b4 commit 964a57a

File tree

3 files changed

+1026
-47
lines changed

3 files changed

+1026
-47
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 156 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,151 @@ def _save_edge_dialect_program(
200200
f"{base_name}_example_inputs", serialized_artifact.example_inputs
201201
)
202202

203+
def add_extra_export_modules(
204+
self,
205+
extra_recorded_export_modules: Dict[
206+
str,
207+
Union[
208+
ExportedProgram,
209+
ExirExportedProgram,
210+
EdgeProgramManager,
211+
],
212+
],
213+
) -> None:
214+
"""
215+
Add extra export modules to the ETRecord after it has been created.
216+
217+
This method allows users to add more export modules they want to record
218+
to an existing ETRecord instance. The modules will be added to the graph_map
219+
and will be included when the ETRecord is saved.
220+
221+
Args:
222+
extra_recorded_export_modules: A dictionary of graph modules with the key being
223+
the user provided name and the value being the corresponding exported module.
224+
The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`.
225+
"""
226+
if self.graph_map is None:
227+
self.graph_map = {}
228+
229+
# Now self.graph_map is guaranteed to be non-None
230+
graph_map = self.graph_map
231+
for module_name, export_module in extra_recorded_export_modules.items():
232+
_add_module_to_graph_map(graph_map, module_name, export_module)
233+
234+
def add_executorch_program(
235+
self,
236+
executorch_program: Union[
237+
ExecutorchProgram,
238+
ExecutorchProgramManager,
239+
BundledProgram,
240+
],
241+
) -> None:
242+
"""
243+
Add executorch program data to the ETRecord after it has been created.
244+
245+
This method allows users to add executorch program data they want to record
246+
to an existing ETRecord instance. The executorch program data includes debug handle map,
247+
delegate map, reference outputs, and representative inputs that will be included
248+
when the ETRecord is saved.
249+
250+
Args:
251+
executorch_program: The ExecuTorch program for this model returned by the call to
252+
`to_executorch()` or the `BundledProgram` of this model.
253+
254+
Raises:
255+
RuntimeError: If executorch program data already exists in the ETRecord.
256+
"""
257+
# Check if executorch program data already exists
258+
if (
259+
self._debug_handle_map is not None
260+
or self._delegate_map is not None
261+
or self._reference_outputs is not None
262+
or self._representative_inputs is not None
263+
):
264+
raise RuntimeError(
265+
"Executorch program data already exists in the ETRecord. "
266+
"Cannot add executorch program data when it already exists."
267+
)
268+
269+
# Process executorch program and extract data
270+
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
271+
_process_executorch_program(executorch_program)
272+
)
273+
274+
# Set the extracted data
275+
self._debug_handle_map = debug_handle_map
276+
self._delegate_map = delegate_map
277+
self._reference_outputs = reference_outputs
278+
self._representative_inputs = representative_inputs
279+
280+
def add_exported_program(
281+
self,
282+
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]],
283+
) -> None:
284+
"""
285+
Add exported program to the ETRecord after it has been created.
286+
287+
This method allows users to add an exported program they want to record
288+
to an existing ETRecord instance. The exported program will be included
289+
when the ETRecord is saved.
290+
291+
Args:
292+
exported_program: The exported program for this model returned by the call to
293+
`torch.export()` or a dictionary with method names as keys and exported programs as values.
294+
Can be None, in which case no exported program data will be added.
295+
296+
Raises:
297+
RuntimeError: If exported program already exists in the ETRecord.
298+
"""
299+
# Check if exported program already exists
300+
if self.exported_program is not None or self.export_graph_id is not None:
301+
raise RuntimeError(
302+
"Exported program already exists in the ETRecord. "
303+
"Cannot add exported program when it already exists."
304+
)
305+
306+
# Process exported program and extract data
307+
processed_exported_program, export_graph_id = _process_exported_program(
308+
exported_program
309+
)
310+
311+
# Set the extracted data
312+
self.exported_program = processed_exported_program
313+
self.export_graph_id = export_graph_id
314+
315+
def add_edge_dialect_program(
316+
self,
317+
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
318+
) -> None:
319+
"""
320+
Add edge dialect program to the ETRecord after it has been created.
321+
322+
This method allows users to add an edge dialect program they want to record
323+
to an existing ETRecord instance. The edge dialect program will be included
324+
when the ETRecord is saved.
325+
326+
Args:
327+
edge_dialect_program: The edge dialect program for this model returned by the call to
328+
`to_edge()` or `EdgeProgramManager` for this model.
329+
330+
Raises:
331+
RuntimeError: If edge dialect program already exists in the ETRecord.
332+
"""
333+
# Check if edge dialect program already exists
334+
if self.edge_dialect_program is not None:
335+
raise RuntimeError(
336+
"Edge dialect program already exists in the ETRecord. "
337+
"Cannot add edge dialect program when it already exists."
338+
)
339+
340+
# Process edge dialect program and extract data
341+
processed_edge_dialect_program = _process_edge_dialect_program(
342+
edge_dialect_program
343+
)
344+
345+
# Set the extracted data
346+
self.edge_dialect_program = processed_edge_dialect_program
347+
203348

204349
def _get_reference_outputs(
205350
bundled_program: BundledProgram,
@@ -285,37 +430,24 @@ def generate_etrecord(
285430
Returns:
286431
None
287432
"""
288-
# Process all inputs and prepare data for ETRecord construction
289-
processed_exported_program, export_graph_id = _process_exported_program(
290-
exported_program
291-
)
292-
graph_map = _process_extra_recorded_modules(extra_recorded_export_modules)
293-
processed_edge_dialect_program = _process_edge_dialect_program(edge_dialect_program)
294-
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
295-
_process_executorch_program(executorch_program)
296-
)
433+
etrecord = ETRecord()
434+
etrecord.add_exported_program(exported_program)
435+
etrecord.add_edge_dialect_program(edge_dialect_program)
436+
etrecord.add_executorch_program(executorch_program)
297437

298-
# Create ETRecord instance and save
299-
etrecord = ETRecord(
300-
exported_program=processed_exported_program,
301-
export_graph_id=export_graph_id,
302-
edge_dialect_program=processed_edge_dialect_program,
303-
graph_map=graph_map if graph_map else None,
304-
_debug_handle_map=debug_handle_map,
305-
_delegate_map=delegate_map,
306-
_reference_outputs=reference_outputs,
307-
_representative_inputs=representative_inputs,
308-
)
438+
# Add extra export modules if user provided
439+
if extra_recorded_export_modules is not None:
440+
etrecord.add_extra_export_modules(extra_recorded_export_modules)
309441

310442
etrecord.save(et_record)
311443

312444

313445
def _process_exported_program(
314446
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]]
315-
) -> tuple[Optional[ExportedProgram], int]:
447+
) -> tuple[Optional[ExportedProgram], Optional[int]]:
316448
"""Process exported program and return the processed program and export graph id."""
317449
processed_exported_program = None
318-
export_graph_id = 0
450+
export_graph_id = None
319451

320452
if exported_program is not None:
321453
if isinstance(exported_program, dict) and "forward" in exported_program:
@@ -329,29 +461,6 @@ def _process_exported_program(
329461
return processed_exported_program, export_graph_id
330462

331463

332-
def _process_extra_recorded_modules(
333-
extra_recorded_export_modules: Optional[
334-
Dict[
335-
str,
336-
Union[
337-
ExportedProgram,
338-
ExirExportedProgram,
339-
EdgeProgramManager,
340-
],
341-
]
342-
]
343-
) -> Dict[str, ExportedProgram]:
344-
"""Process extra recorded export modules and return graph map."""
345-
graph_map = {}
346-
347-
if extra_recorded_export_modules is not None:
348-
for module_name, export_module in extra_recorded_export_modules.items():
349-
_validate_module_name(module_name)
350-
_add_module_to_graph_map(graph_map, module_name, export_module)
351-
352-
return graph_map
353-
354-
355464
def _validate_module_name(module_name: str) -> None:
356465
"""Validate that module name is not a reserved name."""
357466
contains_reserved_name = any(
@@ -369,6 +478,8 @@ def _add_module_to_graph_map(
369478
export_module: Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager],
370479
) -> None:
371480
"""Add export module to graph map based on its type."""
481+
_validate_module_name(module_name)
482+
372483
if isinstance(export_module, ExirExportedProgram):
373484
graph_map[f"{module_name}/forward"] = export_module.exported_program
374485
elif isinstance(export_module, ExportedProgram):

0 commit comments

Comments
 (0)