@@ -55,96 +55,137 @@ class ETRecordReservedFileNames(StrEnum):
5555 REPRESENTATIVE_INPUTS = "representative_inputs"
5656
5757
58- @dataclass
5958class ETRecord :
60- exported_program : Optional [ExportedProgram ] = None
61- export_graph_id : Optional [int ] = None
62- edge_dialect_program : Optional [ExportedProgram ] = None
63- graph_map : Optional [Dict [str , ExportedProgram ]] = None
64- _debug_handle_map : Optional [Dict [int , Union [int , List [int ]]]] = None
65- _delegate_map : Optional [
66- Dict [str , Dict [int , Dict [str , Union [str , _DelegateDebugIdentifierMap ]]]]
67- ] = None
68- _reference_outputs : Optional [Dict [str , List [ProgramOutput ]]] = None
69- _representative_inputs : Optional [List [ProgramOutput ]] = None
70-
71-
72- def _handle_exported_program (
73- etrecord_zip : ZipFile , module_name : str , method_name : str , ep : ExportedProgram
74- ) -> None :
75- assert isinstance (ep , ExportedProgram )
76- serialized_artifact = serialize (ep )
77- assert isinstance (serialized_artifact .exported_program , bytes )
59+ def __init__ (
60+ self ,
61+ exported_program : Optional [ExportedProgram ] = None ,
62+ export_graph_id : Optional [int ] = None ,
63+ edge_dialect_program : Optional [ExportedProgram ] = None ,
64+ graph_map : Optional [Dict [str , ExportedProgram ]] = None ,
65+ _debug_handle_map : Optional [Dict [int , Union [int , List [int ]]]] = None ,
66+ _delegate_map : Optional [
67+ Dict [str , Dict [int , Dict [str , Union [str , _DelegateDebugIdentifierMap ]]]]
68+ ] = None ,
69+ _reference_outputs : Optional [Dict [str , List [ProgramOutput ]]] = None ,
70+ _representative_inputs : Optional [List [ProgramOutput ]] = None ,
71+ ):
72+ self .exported_program = exported_program
73+ self .export_graph_id = export_graph_id
74+ self .edge_dialect_program = edge_dialect_program
75+ self .graph_map = graph_map
76+ self ._debug_handle_map = _debug_handle_map
77+ self ._delegate_map = _delegate_map
78+ self ._reference_outputs = _reference_outputs
79+ self ._representative_inputs = _representative_inputs
80+
81+ def save (self , path : Union [str , os .PathLike , BinaryIO , IO [bytes ]]) -> None :
82+ """
83+ Serialize and save the ETRecord to the specified path.
84+
85+ Args:
86+ path: Path where the ETRecord file will be saved to.
87+ """
88+ if isinstance (path , (str , os .PathLike )):
89+ path = os .fspath (path )
90+
91+ etrecord_zip = ZipFile (path , "w" )
92+
93+ try :
94+ # Write the magic file identifier
95+ etrecord_zip .writestr (ETRecordReservedFileNames .ETRECORD_IDENTIFIER , "" )
96+
97+ # Save exported program if present
98+ if self .exported_program is not None :
99+ self ._save_exported_program (
100+ etrecord_zip ,
101+ ETRecordReservedFileNames .EXPORTED_PROGRAM ,
102+ "" ,
103+ self .exported_program ,
104+ )
78105
79- method_name = f"/{ method_name } " if method_name != "" else ""
106+ # Save edge dialect program if present
107+ if self .edge_dialect_program is not None :
108+ self ._save_edge_dialect_program (etrecord_zip , self .edge_dialect_program )
109+
110+ # Save graph map if present
111+ if self .graph_map is not None :
112+ for module_name , export_module in self .graph_map .items ():
113+ # Extract method name from module_name if it contains "/"
114+ if "/" in module_name :
115+ base_name , method_name = module_name .rsplit ("/" , 1 )
116+ self ._save_exported_program (
117+ etrecord_zip , base_name , method_name , export_module
118+ )
119+ else :
120+ self ._save_exported_program (
121+ etrecord_zip , module_name , "forward" , export_module
122+ )
123+
124+ # Save debug handle map
125+ if self ._debug_handle_map is not None :
126+ etrecord_zip .writestr (
127+ ETRecordReservedFileNames .DEBUG_HANDLE_MAP_NAME ,
128+ json .dumps (self ._debug_handle_map ),
129+ )
80130
81- etrecord_zip .writestr (
82- f"{ module_name } { method_name } " , serialized_artifact .exported_program
83- )
84- etrecord_zip .writestr (
85- f"{ module_name } { method_name } _state_dict" , serialized_artifact .state_dict
86- )
87- etrecord_zip .writestr (
88- f"{ module_name } { method_name } _constants" , serialized_artifact .constants
89- )
90- etrecord_zip .writestr (
91- f"{ module_name } { method_name } _example_inputs" ,
92- serialized_artifact .example_inputs ,
93- )
131+ # Save delegate map
132+ if self ._delegate_map is not None :
133+ etrecord_zip .writestr (
134+ ETRecordReservedFileNames .DELEGATE_MAP_NAME ,
135+ json .dumps (self ._delegate_map ),
136+ )
94137
138+ # Save reference outputs
139+ if self ._reference_outputs is not None :
140+ etrecord_zip .writestr (
141+ ETRecordReservedFileNames .REFERENCE_OUTPUTS ,
142+ pickle .dumps (self ._reference_outputs ),
143+ )
95144
96- def _handle_export_module (
97- etrecord_zip : ZipFile ,
98- export_module : Union [
99- ExirExportedProgram ,
100- EdgeProgramManager ,
101- ExportedProgram ,
102- ],
103- module_name : str ,
104- ) -> None :
105- if isinstance (export_module , ExirExportedProgram ):
106- _handle_exported_program (
107- etrecord_zip , module_name , "forward" , export_module .exported_program
108- )
109- elif isinstance (export_module , ExportedProgram ):
110- _handle_exported_program (etrecord_zip , module_name , "forward" , export_module )
111- elif isinstance (
112- export_module ,
113- (EdgeProgramManager , exir .program ._program .EdgeProgramManager ),
114- ):
115- for method in export_module .methods :
116- _handle_exported_program (
117- etrecord_zip ,
118- module_name ,
119- method ,
120- export_module .exported_program (method ),
121- )
122- else :
123- raise RuntimeError (f"Unsupported graph module type. { type (export_module )} " )
145+ # Save representative inputs
146+ if self ._representative_inputs is not None :
147+ etrecord_zip .writestr (
148+ ETRecordReservedFileNames .REPRESENTATIVE_INPUTS ,
149+ pickle .dumps (self ._representative_inputs ),
150+ )
124151
152+ # Save export graph id
153+ if self .export_graph_id is not None :
154+ etrecord_zip .writestr (
155+ ETRecordReservedFileNames .EXPORT_GRAPH_ID ,
156+ json .dumps (self .export_graph_id ),
157+ )
125158
126- def _handle_edge_dialect_exported_program (
127- etrecord_zip : ZipFile , edge_dialect_exported_program : ExportedProgram
128- ) -> None :
129- serialized_artifact = serialize (edge_dialect_exported_program )
130- assert isinstance (serialized_artifact .exported_program , bytes )
159+ finally :
160+ etrecord_zip .close ()
131161
132- etrecord_zip .writestr (
133- ETRecordReservedFileNames .EDGE_DIALECT_EXPORTED_PROGRAM ,
134- serialized_artifact .exported_program ,
135- )
136- etrecord_zip .writestr (
137- f"{ ETRecordReservedFileNames .EDGE_DIALECT_EXPORTED_PROGRAM } _state_dict" ,
138- serialized_artifact .state_dict ,
139- )
140- etrecord_zip .writestr (
141- f"{ ETRecordReservedFileNames .EDGE_DIALECT_EXPORTED_PROGRAM } _constants" ,
142- serialized_artifact .constants ,
143- )
144- etrecord_zip .writestr (
145- f"{ ETRecordReservedFileNames .EDGE_DIALECT_EXPORTED_PROGRAM } _example_inputs" ,
146- serialized_artifact .example_inputs ,
147- )
162+ def _save_exported_program (
163+ self , etrecord_zip : ZipFile , module_name : str , method_name : str , ep : ExportedProgram
164+ ) -> None :
165+ """Save an exported program to the ETRecord zip file."""
166+ serialized_artifact = serialize (ep )
167+ assert isinstance (serialized_artifact .exported_program , bytes )
168+
169+ method_name = f"/{ method_name } " if method_name != "" else ""
170+ base_name = f"{ module_name } { method_name } "
171+
172+ etrecord_zip .writestr (base_name , serialized_artifact .exported_program )
173+ etrecord_zip .writestr (f"{ base_name } _state_dict" , serialized_artifact .state_dict )
174+ etrecord_zip .writestr (f"{ base_name } _constants" , serialized_artifact .constants )
175+ etrecord_zip .writestr (f"{ base_name } _example_inputs" , serialized_artifact .example_inputs )
176+
177+ def _save_edge_dialect_program (
178+ self , etrecord_zip : ZipFile , edge_dialect_program : ExportedProgram
179+ ) -> None :
180+ """Save the edge dialect program to the ETRecord zip file."""
181+ serialized_artifact = serialize (edge_dialect_program )
182+ assert isinstance (serialized_artifact .exported_program , bytes )
183+
184+ base_name = ETRecordReservedFileNames .EDGE_DIALECT_EXPORTED_PROGRAM
185+ etrecord_zip .writestr (base_name , serialized_artifact .exported_program )
186+ etrecord_zip .writestr (f"{ base_name } _state_dict" , serialized_artifact .state_dict )
187+ etrecord_zip .writestr (f"{ base_name } _constants" , serialized_artifact .constants )
188+ etrecord_zip .writestr (f"{ base_name } _example_inputs" , serialized_artifact .example_inputs )
148189
149190
150191def _get_reference_outputs (
@@ -231,32 +272,27 @@ def generate_etrecord(
231272 Returns:
232273 None
233274 """
234-
235- if isinstance (et_record , (str , os .PathLike )):
236- et_record = os .fspath (et_record ) # pyre-ignore
237-
238- etrecord_zip = ZipFile (et_record , "w" )
239- # Write the magic file identifier that will be used to verify that this file
240- # is an etrecord when it's used later in the Developer Tools.
241- etrecord_zip .writestr (ETRecordReservedFileNames .ETRECORD_IDENTIFIER , "" )
242-
243- # Calculate export_graph_id before modifying exported_program
275+ # Prepare data for ETRecord construction
276+ processed_exported_program = None
244277 export_graph_id = 0
278+ processed_edge_dialect_program = None
279+ graph_map = {}
280+ debug_handle_map = None
281+ delegate_map = None
282+ reference_outputs = None
283+ representative_inputs = None
245284
285+ # Process exported program
246286 if exported_program is not None :
247- # If multiple exported programs are provided, only save forward method
248287 if isinstance (exported_program , dict ) and "forward" in exported_program :
249- exported_program = exported_program ["forward" ]
250-
251- if isinstance (exported_program , ExportedProgram ):
252- export_graph_id = id (exported_program .graph )
253- _handle_exported_program (
254- etrecord_zip ,
255- ETRecordReservedFileNames .EXPORTED_PROGRAM ,
256- "" ,
257- exported_program ,
258- )
288+ processed_exported_program = exported_program ["forward" ]
289+ elif isinstance (exported_program , ExportedProgram ):
290+ processed_exported_program = exported_program
291+
292+ if processed_exported_program is not None :
293+ export_graph_id = id (processed_exported_program .graph )
259294
295+ # Process extra recorded export modules
260296 if extra_recorded_export_modules is not None :
261297 for module_name , export_module in extra_recorded_export_modules .items ():
262298 contains_reserved_name = any (
@@ -267,57 +303,49 @@ def generate_etrecord(
267303 raise RuntimeError (
268304 f"The name { module_name } provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace."
269305 )
270- _handle_export_module (etrecord_zip , export_module , module_name )
271306
272- if isinstance (
273- edge_dialect_program ,
274- (EdgeProgramManager , exir .program ._program .EdgeProgramManager ),
275- ):
276- _handle_edge_dialect_exported_program (
277- etrecord_zip ,
278- edge_dialect_program .exported_program (),
279- )
307+ # Process different types of export modules
308+ if isinstance (export_module , ExirExportedProgram ):
309+ graph_map [f"{ module_name } /forward" ] = export_module .exported_program
310+ elif isinstance (export_module , ExportedProgram ):
311+ graph_map [f"{ module_name } /forward" ] = export_module
312+ elif isinstance (export_module , (EdgeProgramManager , exir .program ._program .EdgeProgramManager )):
313+ for method in export_module .methods :
314+ graph_map [f"{ module_name } /{ method } " ] = export_module .exported_program (method )
315+ else :
316+ raise RuntimeError (f"Unsupported graph module type. { type (export_module )} " )
317+
318+ # Process edge dialect program
319+ if isinstance (edge_dialect_program , (EdgeProgramManager , exir .program ._program .EdgeProgramManager )):
320+ processed_edge_dialect_program = edge_dialect_program .exported_program ()
280321 elif isinstance (edge_dialect_program , ExirExportedProgram ):
281- _handle_edge_dialect_exported_program (
282- etrecord_zip ,
283- edge_dialect_program .exported_program ,
284- )
322+ processed_edge_dialect_program = edge_dialect_program .exported_program
285323 else :
286- raise RuntimeError (
287- f"Unsupported type of edge_dialect_program passed in { type (edge_dialect_program )} ."
288- )
324+ raise RuntimeError (f"Unsupported type of edge_dialect_program passed in { type (edge_dialect_program )} ." )
289325
290- # When a BundledProgram is passed in, extract the reference outputs and save in a file
326+ # Process executorch program
291327 if isinstance (executorch_program , BundledProgram ):
292328 reference_outputs = _get_reference_outputs (executorch_program )
293- etrecord_zip .writestr (
294- ETRecordReservedFileNames .REFERENCE_OUTPUTS ,
295- # @lint-ignore PYTHONPICKLEISBAD
296- pickle .dumps (reference_outputs ),
297- )
298-
299329 representative_inputs = _get_representative_inputs (executorch_program )
300- etrecord_zip .writestr (
301- ETRecordReservedFileNames .REPRESENTATIVE_INPUTS ,
302- # @lint-ignore PYTHONPICKLEISBAD
303- pickle .dumps (representative_inputs ),
304- )
305- executorch_program = executorch_program .executorch_program
306-
307- etrecord_zip .writestr (
308- ETRecordReservedFileNames .DEBUG_HANDLE_MAP_NAME ,
309- json .dumps (executorch_program .debug_handle_map ),
310- )
330+ debug_handle_map = executorch_program .executorch_program .debug_handle_map
331+ delegate_map = executorch_program .executorch_program .delegate_map
332+ else :
333+ debug_handle_map = executorch_program .debug_handle_map
334+ delegate_map = executorch_program .delegate_map
311335
312- etrecord_zip .writestr (
313- ETRecordReservedFileNames .DELEGATE_MAP_NAME ,
314- json .dumps (executorch_program .delegate_map ),
336+ # Create ETRecord instance and save
337+ etrecord = ETRecord (
338+ exported_program = processed_exported_program ,
339+ export_graph_id = export_graph_id ,
340+ edge_dialect_program = processed_edge_dialect_program ,
341+ graph_map = graph_map if graph_map else None ,
342+ _debug_handle_map = debug_handle_map ,
343+ _delegate_map = delegate_map ,
344+ _reference_outputs = reference_outputs ,
345+ _representative_inputs = representative_inputs ,
315346 )
316347
317- etrecord_zip .writestr (
318- ETRecordReservedFileNames .EXPORT_GRAPH_ID ,
319- json .dumps (export_graph_id ),
320- )
348+ etrecord .save (et_record )
321349
322350
323351def parse_etrecord (etrecord_path : str ) -> ETRecord : # noqa: C901
0 commit comments