Skip to content

Commit 4650a00

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
bring etrecord updated "reverted" by gh patch fix bot back
Summary: D79279401 D79336982 and D79294945 was landed last week but got "reverted" by gh patch fix on Saturday D78689027 due to didn't merge gh PR on time. This diff brings the updates back. Differential Revision: D79599520
1 parent 07b6059 commit 4650a00

File tree

3 files changed

+1026
-47
lines changed

3 files changed

+1026
-47
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 156 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,151 @@ def _save_edge_dialect_program(
200200
f"{base_name}_example_inputs", serialized_artifact.example_inputs
201201
)
202202

203+
def add_extra_export_modules(
204+
self,
205+
extra_recorded_export_modules: Dict[
206+
str,
207+
Union[
208+
ExportedProgram,
209+
ExirExportedProgram,
210+
EdgeProgramManager,
211+
],
212+
],
213+
) -> None:
214+
"""
215+
Add extra export modules to the ETRecord after it has been created.
216+
217+
This method allows users to add more export modules they want to record
218+
to an existing ETRecord instance. The modules will be added to the graph_map
219+
and will be included when the ETRecord is saved.
220+
221+
Args:
222+
extra_recorded_export_modules: A dictionary of graph modules with the key being
223+
the user provided name and the value being the corresponding exported module.
224+
The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`.
225+
"""
226+
if self.graph_map is None:
227+
self.graph_map = {}
228+
229+
# Now self.graph_map is guaranteed to be non-None
230+
graph_map = self.graph_map
231+
for module_name, export_module in extra_recorded_export_modules.items():
232+
_add_module_to_graph_map(graph_map, module_name, export_module)
233+
234+
def add_executorch_program(
235+
self,
236+
executorch_program: Union[
237+
ExecutorchProgram,
238+
ExecutorchProgramManager,
239+
BundledProgram,
240+
],
241+
) -> None:
242+
"""
243+
Add executorch program data to the ETRecord after it has been created.
244+
245+
This method allows users to add executorch program data they want to record
246+
to an existing ETRecord instance. The executorch program data includes debug handle map,
247+
delegate map, reference outputs, and representative inputs that will be included
248+
when the ETRecord is saved.
249+
250+
Args:
251+
executorch_program: The ExecuTorch program for this model returned by the call to
252+
`to_executorch()` or the `BundledProgram` of this model.
253+
254+
Raises:
255+
RuntimeError: If executorch program data already exists in the ETRecord.
256+
"""
257+
# Check if executorch program data already exists
258+
if (
259+
self._debug_handle_map is not None
260+
or self._delegate_map is not None
261+
or self._reference_outputs is not None
262+
or self._representative_inputs is not None
263+
):
264+
raise RuntimeError(
265+
"Executorch program data already exists in the ETRecord. "
266+
"Cannot add executorch program data when it already exists."
267+
)
268+
269+
# Process executorch program and extract data
270+
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
271+
_process_executorch_program(executorch_program)
272+
)
273+
274+
# Set the extracted data
275+
self._debug_handle_map = debug_handle_map
276+
self._delegate_map = delegate_map
277+
self._reference_outputs = reference_outputs
278+
self._representative_inputs = representative_inputs
279+
280+
def add_exported_program(
281+
self,
282+
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]],
283+
) -> None:
284+
"""
285+
Add exported program to the ETRecord after it has been created.
286+
287+
This method allows users to add an exported program they want to record
288+
to an existing ETRecord instance. The exported program will be included
289+
when the ETRecord is saved.
290+
291+
Args:
292+
exported_program: The exported program for this model returned by the call to
293+
`torch.export()` or a dictionary with method names as keys and exported programs as values.
294+
Can be None, in which case no exported program data will be added.
295+
296+
Raises:
297+
RuntimeError: If exported program already exists in the ETRecord.
298+
"""
299+
# Check if exported program already exists
300+
if self.exported_program is not None or self.export_graph_id is not None:
301+
raise RuntimeError(
302+
"Exported program already exists in the ETRecord. "
303+
"Cannot add exported program when it already exists."
304+
)
305+
306+
# Process exported program and extract data
307+
processed_exported_program, export_graph_id = _process_exported_program(
308+
exported_program
309+
)
310+
311+
# Set the extracted data
312+
self.exported_program = processed_exported_program
313+
self.export_graph_id = export_graph_id
314+
315+
def add_edge_dialect_program(
316+
self,
317+
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
318+
) -> None:
319+
"""
320+
Add edge dialect program to the ETRecord after it has been created.
321+
322+
This method allows users to add an edge dialect program they want to record
323+
to an existing ETRecord instance. The edge dialect program will be included
324+
when the ETRecord is saved.
325+
326+
Args:
327+
edge_dialect_program: The edge dialect program for this model returned by the call to
328+
`to_edge()` or `EdgeProgramManager` for this model.
329+
330+
Raises:
331+
RuntimeError: If edge dialect program already exists in the ETRecord.
332+
"""
333+
# Check if edge dialect program already exists
334+
if self.edge_dialect_program is not None:
335+
raise RuntimeError(
336+
"Edge dialect program already exists in the ETRecord. "
337+
"Cannot add edge dialect program when it already exists."
338+
)
339+
340+
# Process edge dialect program and extract data
341+
processed_edge_dialect_program = _process_edge_dialect_program(
342+
edge_dialect_program
343+
)
344+
345+
# Set the extracted data
346+
self.edge_dialect_program = processed_edge_dialect_program
347+
203348

204349
def _get_reference_outputs(
205350
bundled_program: BundledProgram,
@@ -285,37 +430,24 @@ def generate_etrecord(
285430
Returns:
286431
None
287432
"""
288-
# Process all inputs and prepare data for ETRecord construction
289-
processed_exported_program, export_graph_id = _process_exported_program(
290-
exported_program
291-
)
292-
graph_map = _process_extra_recorded_modules(extra_recorded_export_modules)
293-
processed_edge_dialect_program = _process_edge_dialect_program(edge_dialect_program)
294-
debug_handle_map, delegate_map, reference_outputs, representative_inputs = (
295-
_process_executorch_program(executorch_program)
296-
)
433+
etrecord = ETRecord()
434+
etrecord.add_exported_program(exported_program)
435+
etrecord.add_edge_dialect_program(edge_dialect_program)
436+
etrecord.add_executorch_program(executorch_program)
297437

298-
# Create ETRecord instance and save
299-
etrecord = ETRecord(
300-
exported_program=processed_exported_program,
301-
export_graph_id=export_graph_id,
302-
edge_dialect_program=processed_edge_dialect_program,
303-
graph_map=graph_map if graph_map else None,
304-
_debug_handle_map=debug_handle_map,
305-
_delegate_map=delegate_map,
306-
_reference_outputs=reference_outputs,
307-
_representative_inputs=representative_inputs,
308-
)
438+
# Add extra export modules if user provided
439+
if extra_recorded_export_modules is not None:
440+
etrecord.add_extra_export_modules(extra_recorded_export_modules)
309441

310442
etrecord.save(et_record)
311443

312444

313445
def _process_exported_program(
314446
exported_program: Optional[Union[ExportedProgram, Dict[str, ExportedProgram]]]
315-
) -> tuple[Optional[ExportedProgram], int]:
447+
) -> tuple[Optional[ExportedProgram], Optional[int]]:
316448
"""Process exported program and return the processed program and export graph id."""
317449
processed_exported_program = None
318-
export_graph_id = 0
450+
export_graph_id = None
319451

320452
if exported_program is not None:
321453
if isinstance(exported_program, dict) and "forward" in exported_program:
@@ -329,29 +461,6 @@ def _process_exported_program(
329461
return processed_exported_program, export_graph_id
330462

331463

332-
def _process_extra_recorded_modules(
333-
extra_recorded_export_modules: Optional[
334-
Dict[
335-
str,
336-
Union[
337-
ExportedProgram,
338-
ExirExportedProgram,
339-
EdgeProgramManager,
340-
],
341-
]
342-
]
343-
) -> Dict[str, ExportedProgram]:
344-
"""Process extra recorded export modules and return graph map."""
345-
graph_map = {}
346-
347-
if extra_recorded_export_modules is not None:
348-
for module_name, export_module in extra_recorded_export_modules.items():
349-
_validate_module_name(module_name)
350-
_add_module_to_graph_map(graph_map, module_name, export_module)
351-
352-
return graph_map
353-
354-
355464
def _validate_module_name(module_name: str) -> None:
356465
"""Validate that module name is not a reserved name."""
357466
contains_reserved_name = any(
@@ -369,6 +478,8 @@ def _add_module_to_graph_map(
369478
export_module: Union[ExportedProgram, ExirExportedProgram, EdgeProgramManager],
370479
) -> None:
371480
"""Add export module to graph map based on its type."""
481+
_validate_module_name(module_name)
482+
372483
if isinstance(export_module, ExirExportedProgram):
373484
graph_map[f"{module_name}/forward"] = export_module.exported_program
374485
elif isinstance(export_module, ExportedProgram):

0 commit comments

Comments
 (0)