Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 172 additions & 144 deletions devtools/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,96 +55,137 @@ class ETRecordReservedFileNames(StrEnum):
REPRESENTATIVE_INPUTS = "representative_inputs"


@dataclass
class ETRecord:
exported_program: Optional[ExportedProgram] = None
export_graph_id: Optional[int] = None
edge_dialect_program: Optional[ExportedProgram] = None
graph_map: Optional[Dict[str, ExportedProgram]] = None
_debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
_delegate_map: Optional[
Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
] = None
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None
_representative_inputs: Optional[List[ProgramOutput]] = None


def _handle_exported_program(
etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram
) -> None:
assert isinstance(ep, ExportedProgram)
serialized_artifact = serialize(ep)
assert isinstance(serialized_artifact.exported_program, bytes)
def __init__(
    self,
    exported_program: Optional[ExportedProgram] = None,
    export_graph_id: Optional[int] = None,
    edge_dialect_program: Optional[ExportedProgram] = None,
    graph_map: Optional[Dict[str, ExportedProgram]] = None,
    _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None,
    _delegate_map: Optional[
        Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
    ] = None,
    _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None,
    _representative_inputs: Optional[List[ProgramOutput]] = None,
):
    """
    Store the individual pieces of an ETRecord on the instance.

    Every argument is optional and defaults to None; the values are kept
    as given, without validation or copying.

    Args:
        exported_program: The exported program to record.
        export_graph_id: Integer identifier associated with the exported
            program's graph.
        edge_dialect_program: The edge dialect program to record.
        graph_map: Additional exported programs keyed by name.
        _debug_handle_map: Debug handle mapping to record.
        _delegate_map: Delegate mapping to record.
        _reference_outputs: Reference outputs keyed by name.
        _representative_inputs: Representative inputs to record.
    """
    # Programs and the identifier tied to the exported graph.
    self.exported_program = exported_program
    self.export_graph_id = export_graph_id
    self.edge_dialect_program = edge_dialect_program
    self.graph_map = graph_map

    # Debug / delegate lookup tables.
    self._debug_handle_map = _debug_handle_map
    self._delegate_map = _delegate_map

    # Recorded outputs and inputs.
    self._reference_outputs = _reference_outputs
    self._representative_inputs = _representative_inputs

def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None:
"""
Serialize and save the ETRecord to the specified path.

Args:
path: Path where the ETRecord file will be saved to.
"""
if isinstance(path, (str, os.PathLike)):
path = os.fspath(path)

etrecord_zip = ZipFile(path, "w")

try:
# Write the magic file identifier
etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

# Save exported program if present
if self.exported_program is not None:
self._save_exported_program(
etrecord_zip,
ETRecordReservedFileNames.EXPORTED_PROGRAM,
"",
self.exported_program,
)

method_name = f"/{method_name}" if method_name != "" else ""
# Save edge dialect program if present
if self.edge_dialect_program is not None:
self._save_edge_dialect_program(etrecord_zip, self.edge_dialect_program)

# Save graph map if present
if self.graph_map is not None:
for module_name, export_module in self.graph_map.items():
# Extract method name from module_name if it contains "/"
if "/" in module_name:
base_name, method_name = module_name.rsplit("/", 1)
self._save_exported_program(
etrecord_zip, base_name, method_name, export_module
)
else:
self._save_exported_program(
etrecord_zip, module_name, "forward", export_module
)

# Save debug handle map
if self._debug_handle_map is not None:
etrecord_zip.writestr(
ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
json.dumps(self._debug_handle_map),
)

etrecord_zip.writestr(
f"{module_name}{method_name}", serialized_artifact.exported_program
)
etrecord_zip.writestr(
f"{module_name}{method_name}_state_dict", serialized_artifact.state_dict
)
etrecord_zip.writestr(
f"{module_name}{method_name}_constants", serialized_artifact.constants
)
etrecord_zip.writestr(
f"{module_name}{method_name}_example_inputs",
serialized_artifact.example_inputs,
)
# Save delegate map
if self._delegate_map is not None:
etrecord_zip.writestr(
ETRecordReservedFileNames.DELEGATE_MAP_NAME,
json.dumps(self._delegate_map),
)

# Save reference outputs
if self._reference_outputs is not None:
etrecord_zip.writestr(
ETRecordReservedFileNames.REFERENCE_OUTPUTS,
pickle.dumps(self._reference_outputs),
)

def _handle_export_module(
    etrecord_zip: ZipFile,
    export_module: Union[
        ExirExportedProgram,
        EdgeProgramManager,
        ExportedProgram,
    ],
    module_name: str,
) -> None:
    """
    Record every program carried by ``export_module`` under ``module_name``.

    Args:
        etrecord_zip: Open archive that the programs are written into.
        export_module: The module to record. An ``ExirExportedProgram`` or a
            plain ``ExportedProgram`` is recorded once with the method name
            "forward"; an ``EdgeProgramManager`` is recorded once per entry
            in its ``methods``.
        module_name: Name the recorded programs are stored under.

    Raises:
        RuntimeError: If ``export_module`` is none of the supported types.
    """
    # Single-program wrappers: unwrap and record as "forward".
    if isinstance(export_module, ExirExportedProgram):
        _handle_exported_program(
            etrecord_zip, module_name, "forward", export_module.exported_program
        )
        return

    if isinstance(export_module, ExportedProgram):
        _handle_exported_program(etrecord_zip, module_name, "forward", export_module)
        return

    # Anything that is not a program manager at this point is unsupported.
    if not isinstance(
        export_module,
        (EdgeProgramManager, exir.program._program.EdgeProgramManager),
    ):
        raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")

    # Program manager: one recorded program per method it exposes.
    for method in export_module.methods:
        method_program = export_module.exported_program(method)
        _handle_exported_program(etrecord_zip, module_name, method, method_program)
# Save representative inputs
if self._representative_inputs is not None:
etrecord_zip.writestr(
ETRecordReservedFileNames.REPRESENTATIVE_INPUTS,
pickle.dumps(self._representative_inputs),
)

# Save export graph id
if self.export_graph_id is not None:
etrecord_zip.writestr(
ETRecordReservedFileNames.EXPORT_GRAPH_ID,
json.dumps(self.export_graph_id),
)

def _handle_edge_dialect_exported_program(
etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram
) -> None:
serialized_artifact = serialize(edge_dialect_exported_program)
assert isinstance(serialized_artifact.exported_program, bytes)
finally:
etrecord_zip.close()

etrecord_zip.writestr(
ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM,
serialized_artifact.exported_program,
)
etrecord_zip.writestr(
f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_state_dict",
serialized_artifact.state_dict,
)
etrecord_zip.writestr(
f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_constants",
serialized_artifact.constants,
)
etrecord_zip.writestr(
f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_example_inputs",
serialized_artifact.example_inputs,
)
def _save_exported_program(
    self, etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram
) -> None:
    """
    Serialize ``ep`` and write its four parts into the ETRecord zip file.

    The entries are named ``<module_name>/<method_name>`` (or just
    ``<module_name>`` when ``method_name`` is empty), followed by the same
    name with the ``_state_dict``, ``_constants`` and ``_example_inputs``
    suffixes.

    Args:
        etrecord_zip: Open archive to write into.
        module_name: Leading part of every entry name.
        method_name: Method the program belongs to; "" omits the
            "/<method_name>" part of the entry name.
        ep: The exported program to serialize.
    """
    artifact = serialize(ep)
    assert isinstance(artifact.exported_program, bytes)

    method_suffix = "" if method_name == "" else f"/{method_name}"
    entry_prefix = f"{module_name}{method_suffix}"

    # Order matters only in that it mirrors the order the entries are written.
    entries = (
        (entry_prefix, artifact.exported_program),
        (f"{entry_prefix}_state_dict", artifact.state_dict),
        (f"{entry_prefix}_constants", artifact.constants),
        (f"{entry_prefix}_example_inputs", artifact.example_inputs),
    )
    for entry_name, payload in entries:
        etrecord_zip.writestr(entry_name, payload)

def _save_edge_dialect_program(
    self, etrecord_zip: ZipFile, edge_dialect_program: ExportedProgram
) -> None:
    """
    Serialize the edge dialect program and write it into the ETRecord zip file.

    Four entries are written, all based on the reserved name
    ``ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM``: the program
    itself, then its ``_state_dict``, ``_constants`` and ``_example_inputs``.

    Args:
        etrecord_zip: Open archive to write into.
        edge_dialect_program: The edge dialect program to serialize.
    """
    artifact = serialize(edge_dialect_program)
    assert isinstance(artifact.exported_program, bytes)

    reserved_name = ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM

    # The first entry uses the reserved enum member directly as its name;
    # the remaining names are built from it with f-strings.
    entries = (
        (reserved_name, artifact.exported_program),
        (f"{reserved_name}_state_dict", artifact.state_dict),
        (f"{reserved_name}_constants", artifact.constants),
        (f"{reserved_name}_example_inputs", artifact.example_inputs),
    )
    for entry_name, payload in entries:
        etrecord_zip.writestr(entry_name, payload)


def _get_reference_outputs(
Expand Down Expand Up @@ -231,32 +272,27 @@ def generate_etrecord(
Returns:
None
"""

if isinstance(et_record, (str, os.PathLike)):
et_record = os.fspath(et_record) # pyre-ignore

etrecord_zip = ZipFile(et_record, "w")
# Write the magic file identifier that will be used to verify that this file
# is an etrecord when it's used later in the Developer Tools.
etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

# Calculate export_graph_id before modifying exported_program
# Prepare data for ETRecord construction
processed_exported_program = None
export_graph_id = 0
processed_edge_dialect_program = None
graph_map = {}
debug_handle_map = None
delegate_map = None
reference_outputs = None
representative_inputs = None

# Process exported program
if exported_program is not None:
# If multiple exported programs are provided, only save forward method
if isinstance(exported_program, dict) and "forward" in exported_program:
exported_program = exported_program["forward"]

if isinstance(exported_program, ExportedProgram):
export_graph_id = id(exported_program.graph)
_handle_exported_program(
etrecord_zip,
ETRecordReservedFileNames.EXPORTED_PROGRAM,
"",
exported_program,
)
processed_exported_program = exported_program["forward"]
elif isinstance(exported_program, ExportedProgram):
processed_exported_program = exported_program

if processed_exported_program is not None:
export_graph_id = id(processed_exported_program.graph)

# Process extra recorded export modules
if extra_recorded_export_modules is not None:
for module_name, export_module in extra_recorded_export_modules.items():
contains_reserved_name = any(
Expand All @@ -267,57 +303,49 @@ def generate_etrecord(
raise RuntimeError(
f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace."
)
_handle_export_module(etrecord_zip, export_module, module_name)

if isinstance(
edge_dialect_program,
(EdgeProgramManager, exir.program._program.EdgeProgramManager),
):
_handle_edge_dialect_exported_program(
etrecord_zip,
edge_dialect_program.exported_program(),
)
# Process different types of export modules
if isinstance(export_module, ExirExportedProgram):
graph_map[f"{module_name}/forward"] = export_module.exported_program
elif isinstance(export_module, ExportedProgram):
graph_map[f"{module_name}/forward"] = export_module
elif isinstance(export_module, (EdgeProgramManager, exir.program._program.EdgeProgramManager)):
for method in export_module.methods:
graph_map[f"{module_name}/{method}"] = export_module.exported_program(method)
else:
raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")

# Process edge dialect program
if isinstance(edge_dialect_program, (EdgeProgramManager, exir.program._program.EdgeProgramManager)):
processed_edge_dialect_program = edge_dialect_program.exported_program()
elif isinstance(edge_dialect_program, ExirExportedProgram):
_handle_edge_dialect_exported_program(
etrecord_zip,
edge_dialect_program.exported_program,
)
processed_edge_dialect_program = edge_dialect_program.exported_program
else:
raise RuntimeError(
f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
)
raise RuntimeError(f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}.")

# When a BundledProgram is passed in, extract the reference outputs and save in a file
# Process executorch program
if isinstance(executorch_program, BundledProgram):
reference_outputs = _get_reference_outputs(executorch_program)
etrecord_zip.writestr(
ETRecordReservedFileNames.REFERENCE_OUTPUTS,
# @lint-ignore PYTHONPICKLEISBAD
pickle.dumps(reference_outputs),
)

representative_inputs = _get_representative_inputs(executorch_program)
etrecord_zip.writestr(
ETRecordReservedFileNames.REPRESENTATIVE_INPUTS,
# @lint-ignore PYTHONPICKLEISBAD
pickle.dumps(representative_inputs),
)
executorch_program = executorch_program.executorch_program

etrecord_zip.writestr(
ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
json.dumps(executorch_program.debug_handle_map),
)
debug_handle_map = executorch_program.executorch_program.debug_handle_map
delegate_map = executorch_program.executorch_program.delegate_map
else:
debug_handle_map = executorch_program.debug_handle_map
delegate_map = executorch_program.delegate_map

etrecord_zip.writestr(
ETRecordReservedFileNames.DELEGATE_MAP_NAME,
json.dumps(executorch_program.delegate_map),
# Create ETRecord instance and save
etrecord = ETRecord(
exported_program=processed_exported_program,
export_graph_id=export_graph_id,
edge_dialect_program=processed_edge_dialect_program,
graph_map=graph_map if graph_map else None,
_debug_handle_map=debug_handle_map,
_delegate_map=delegate_map,
_reference_outputs=reference_outputs,
_representative_inputs=representative_inputs,
)

etrecord_zip.writestr(
ETRecordReservedFileNames.EXPORT_GRAPH_ID,
json.dumps(export_graph_id),
)
etrecord.save(et_record)


def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
Expand Down
Loading
Loading