Skip to content

Commit f035b43

Browse files
committed
equip etrecord class with save method
Differential Revision: [D79205242](https://our.internmc.facebook.com/intern/diff/D79205242/) [ghstack-poisoned]
1 parent 0c8879e commit f035b43

File tree

2 files changed

+287
-144
lines changed

2 files changed

+287
-144
lines changed

devtools/etrecord/_etrecord.py

Lines changed: 172 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -55,96 +55,137 @@ class ETRecordReservedFileNames(StrEnum):
5555
REPRESENTATIVE_INPUTS = "representative_inputs"
5656

5757

58-
@dataclass
5958
class ETRecord:
60-
exported_program: Optional[ExportedProgram] = None
61-
export_graph_id: Optional[int] = None
62-
edge_dialect_program: Optional[ExportedProgram] = None
63-
graph_map: Optional[Dict[str, ExportedProgram]] = None
64-
_debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None
65-
_delegate_map: Optional[
66-
Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
67-
] = None
68-
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None
69-
_representative_inputs: Optional[List[ProgramOutput]] = None
70-
71-
72-
def _handle_exported_program(
73-
etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram
74-
) -> None:
75-
assert isinstance(ep, ExportedProgram)
76-
serialized_artifact = serialize(ep)
77-
assert isinstance(serialized_artifact.exported_program, bytes)
59+
def __init__(
60+
self,
61+
exported_program: Optional[ExportedProgram] = None,
62+
export_graph_id: Optional[int] = None,
63+
edge_dialect_program: Optional[ExportedProgram] = None,
64+
graph_map: Optional[Dict[str, ExportedProgram]] = None,
65+
_debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None,
66+
_delegate_map: Optional[
67+
Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
68+
] = None,
69+
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None,
70+
_representative_inputs: Optional[List[ProgramOutput]] = None,
71+
):
72+
self.exported_program = exported_program
73+
self.export_graph_id = export_graph_id
74+
self.edge_dialect_program = edge_dialect_program
75+
self.graph_map = graph_map
76+
self._debug_handle_map = _debug_handle_map
77+
self._delegate_map = _delegate_map
78+
self._reference_outputs = _reference_outputs
79+
self._representative_inputs = _representative_inputs
80+
81+
def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None:
    """
    Serialize the ETRecord into a zip archive at the given destination.

    The archive always contains the ETRecord identifier entry. Every other
    field is written only when it is not None.

    Args:
        path: Where the ETRecord is written. Either a filesystem path
            (``str`` or ``os.PathLike``) or a binary file object, which is
            handed to ``ZipFile`` as-is.
    """
    if isinstance(path, (str, os.PathLike)):
        path = os.fspath(path)

    # The context manager closes the archive on every exit path, including
    # when one of the writes below raises.
    with ZipFile(path, "w") as etrecord_zip:
        # Write the magic file identifier
        etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

        # Save exported program if present
        if self.exported_program is not None:
            self._save_exported_program(
                etrecord_zip,
                ETRecordReservedFileNames.EXPORTED_PROGRAM,
                "",
                self.exported_program,
            )

        # Save edge dialect program if present
        if self.edge_dialect_program is not None:
            self._save_edge_dialect_program(etrecord_zip, self.edge_dialect_program)

        # Save graph map if present
        if self.graph_map is not None:
            for module_name, export_module in self.graph_map.items():
                # A key of the form "<module>/<method>" is split on its last
                # "/"; a key without "/" is stored under the "forward" method.
                if "/" in module_name:
                    base_name, method_name = module_name.rsplit("/", 1)
                else:
                    base_name, method_name = module_name, "forward"
                self._save_exported_program(
                    etrecord_zip, base_name, method_name, export_module
                )

        # Save debug handle map
        if self._debug_handle_map is not None:
            etrecord_zip.writestr(
                ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
                json.dumps(self._debug_handle_map),
            )

        # Save delegate map
        if self._delegate_map is not None:
            etrecord_zip.writestr(
                ETRecordReservedFileNames.DELEGATE_MAP_NAME,
                json.dumps(self._delegate_map),
            )

        # Save reference outputs
        if self._reference_outputs is not None:
            etrecord_zip.writestr(
                ETRecordReservedFileNames.REFERENCE_OUTPUTS,
                # @lint-ignore PYTHONPICKLEISBAD
                pickle.dumps(self._reference_outputs),
            )

        # Save representative inputs
        if self._representative_inputs is not None:
            etrecord_zip.writestr(
                ETRecordReservedFileNames.REPRESENTATIVE_INPUTS,
                # @lint-ignore PYTHONPICKLEISBAD
                pickle.dumps(self._representative_inputs),
            )

        # Save export graph id
        if self.export_graph_id is not None:
            etrecord_zip.writestr(
                ETRecordReservedFileNames.EXPORT_GRAPH_ID,
                json.dumps(self.export_graph_id),
            )
131161

132-
etrecord_zip.writestr(
133-
ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM,
134-
serialized_artifact.exported_program,
135-
)
136-
etrecord_zip.writestr(
137-
f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_state_dict",
138-
serialized_artifact.state_dict,
139-
)
140-
etrecord_zip.writestr(
141-
f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_constants",
142-
serialized_artifact.constants,
143-
)
144-
etrecord_zip.writestr(
145-
f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_example_inputs",
146-
serialized_artifact.example_inputs,
147-
)
162+
def _save_exported_program(
    self, etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram
) -> None:
    """
    Serialize ``ep`` and write it into the ETRecord zip file.

    Four entries are written. The serialized program goes under
    ``module_name`` (followed by ``/method_name`` when ``method_name`` is
    not empty). The state dict, constants and example inputs go under that
    same name with a ``_state_dict``, ``_constants`` or ``_example_inputs``
    suffix.
    """
    artifact = serialize(ep)
    assert isinstance(artifact.exported_program, bytes)

    # An empty method name means the entry lives directly under module_name.
    method_suffix = "" if method_name == "" else f"/{method_name}"
    entry_prefix = f"{module_name}{method_suffix}"

    etrecord_zip.writestr(entry_prefix, artifact.exported_program)
    for field_name in ("state_dict", "constants", "example_inputs"):
        etrecord_zip.writestr(
            f"{entry_prefix}_{field_name}", getattr(artifact, field_name)
        )
176+
177+
def _save_edge_dialect_program(
    self, etrecord_zip: ZipFile, edge_dialect_program: ExportedProgram
) -> None:
    """
    Serialize the edge dialect program and write it into the ETRecord zip file.

    The serialized program is stored under the reserved
    ``EDGE_DIALECT_EXPORTED_PROGRAM`` entry name. Its state dict, constants
    and example inputs are stored under that name with a ``_state_dict``,
    ``_constants`` or ``_example_inputs`` suffix.
    """
    artifact = serialize(edge_dialect_program)
    assert isinstance(artifact.exported_program, bytes)

    entry_name = ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM
    etrecord_zip.writestr(entry_name, artifact.exported_program)
    for field_name in ("state_dict", "constants", "example_inputs"):
        etrecord_zip.writestr(
            f"{entry_name}_{field_name}", getattr(artifact, field_name)
        )
148189

149190

150191
def _get_reference_outputs(
@@ -231,32 +272,27 @@ def generate_etrecord(
231272
Returns:
232273
None
233274
"""
234-
235-
if isinstance(et_record, (str, os.PathLike)):
236-
et_record = os.fspath(et_record) # pyre-ignore
237-
238-
etrecord_zip = ZipFile(et_record, "w")
239-
# Write the magic file identifier that will be used to verify that this file
240-
# is an etrecord when it's used later in the Developer Tools.
241-
etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")
242-
243-
# Calculate export_graph_id before modifying exported_program
275+
# Prepare data for ETRecord construction
276+
processed_exported_program = None
244277
export_graph_id = 0
278+
processed_edge_dialect_program = None
279+
graph_map = {}
280+
debug_handle_map = None
281+
delegate_map = None
282+
reference_outputs = None
283+
representative_inputs = None
245284

285+
# Process exported program
246286
if exported_program is not None:
247-
# If multiple exported programs are provided, only save forward method
248287
if isinstance(exported_program, dict) and "forward" in exported_program:
249-
exported_program = exported_program["forward"]
250-
251-
if isinstance(exported_program, ExportedProgram):
252-
export_graph_id = id(exported_program.graph)
253-
_handle_exported_program(
254-
etrecord_zip,
255-
ETRecordReservedFileNames.EXPORTED_PROGRAM,
256-
"",
257-
exported_program,
258-
)
288+
processed_exported_program = exported_program["forward"]
289+
elif isinstance(exported_program, ExportedProgram):
290+
processed_exported_program = exported_program
291+
292+
if processed_exported_program is not None:
293+
export_graph_id = id(processed_exported_program.graph)
259294

295+
# Process extra recorded export modules
260296
if extra_recorded_export_modules is not None:
261297
for module_name, export_module in extra_recorded_export_modules.items():
262298
contains_reserved_name = any(
@@ -267,57 +303,49 @@ def generate_etrecord(
267303
raise RuntimeError(
268304
f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace."
269305
)
270-
_handle_export_module(etrecord_zip, export_module, module_name)
271306

272-
if isinstance(
273-
edge_dialect_program,
274-
(EdgeProgramManager, exir.program._program.EdgeProgramManager),
275-
):
276-
_handle_edge_dialect_exported_program(
277-
etrecord_zip,
278-
edge_dialect_program.exported_program(),
279-
)
307+
# Process different types of export modules
308+
if isinstance(export_module, ExirExportedProgram):
309+
graph_map[f"{module_name}/forward"] = export_module.exported_program
310+
elif isinstance(export_module, ExportedProgram):
311+
graph_map[f"{module_name}/forward"] = export_module
312+
elif isinstance(export_module, (EdgeProgramManager, exir.program._program.EdgeProgramManager)):
313+
for method in export_module.methods:
314+
graph_map[f"{module_name}/{method}"] = export_module.exported_program(method)
315+
else:
316+
raise RuntimeError(f"Unsupported graph module type. {type(export_module)}")
317+
318+
# Process edge dialect program
319+
if isinstance(edge_dialect_program, (EdgeProgramManager, exir.program._program.EdgeProgramManager)):
320+
processed_edge_dialect_program = edge_dialect_program.exported_program()
280321
elif isinstance(edge_dialect_program, ExirExportedProgram):
281-
_handle_edge_dialect_exported_program(
282-
etrecord_zip,
283-
edge_dialect_program.exported_program,
284-
)
322+
processed_edge_dialect_program = edge_dialect_program.exported_program
285323
else:
286-
raise RuntimeError(
287-
f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
288-
)
324+
raise RuntimeError(f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}.")
289325

290-
# When a BundledProgram is passed in, extract the reference outputs and save in a file
326+
# Process executorch program
291327
if isinstance(executorch_program, BundledProgram):
292328
reference_outputs = _get_reference_outputs(executorch_program)
293-
etrecord_zip.writestr(
294-
ETRecordReservedFileNames.REFERENCE_OUTPUTS,
295-
# @lint-ignore PYTHONPICKLEISBAD
296-
pickle.dumps(reference_outputs),
297-
)
298-
299329
representative_inputs = _get_representative_inputs(executorch_program)
300-
etrecord_zip.writestr(
301-
ETRecordReservedFileNames.REPRESENTATIVE_INPUTS,
302-
# @lint-ignore PYTHONPICKLEISBAD
303-
pickle.dumps(representative_inputs),
304-
)
305-
executorch_program = executorch_program.executorch_program
306-
307-
etrecord_zip.writestr(
308-
ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME,
309-
json.dumps(executorch_program.debug_handle_map),
310-
)
330+
debug_handle_map = executorch_program.executorch_program.debug_handle_map
331+
delegate_map = executorch_program.executorch_program.delegate_map
332+
else:
333+
debug_handle_map = executorch_program.debug_handle_map
334+
delegate_map = executorch_program.delegate_map
311335

312-
etrecord_zip.writestr(
313-
ETRecordReservedFileNames.DELEGATE_MAP_NAME,
314-
json.dumps(executorch_program.delegate_map),
336+
# Create ETRecord instance and save
337+
etrecord = ETRecord(
338+
exported_program=processed_exported_program,
339+
export_graph_id=export_graph_id,
340+
edge_dialect_program=processed_edge_dialect_program,
341+
graph_map=graph_map if graph_map else None,
342+
_debug_handle_map=debug_handle_map,
343+
_delegate_map=delegate_map,
344+
_reference_outputs=reference_outputs,
345+
_representative_inputs=representative_inputs,
315346
)
316347

317-
etrecord_zip.writestr(
318-
ETRecordReservedFileNames.EXPORT_GRAPH_ID,
319-
json.dumps(export_graph_id),
320-
)
348+
etrecord.save(et_record)
321349

322350

323351
def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901

0 commit comments

Comments
 (0)