 # CoreML backend for delegating an EdgeProgram to CoreML.
 
 import json
+
 import shutil
 import uuid
 from dataclasses import asdict, dataclass
 from enum import Enum
 
 from pathlib import Path
 
-from typing import Dict, final, List
+from typing import Any, Dict, final, List, Optional, Tuple
 
 import coremltools as ct
 import executorchcoreml
@@ -30,6 +31,13 @@ class COMPILE_SPEC_KEYS(Enum):
     MODEL_COMPUTE_PRECISION = "model_compute_precision"
 
 
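+# File and directory names written into the flattened "lowered_module" payload.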
+class MODEL_PATHS(Enum):
+    MODEL = "model.mlpackage"
+    COMPILED_MODEL = "model.mlmodelc"
+    METADATA = "metadata.json"
+    DEBUG_INFO = "debug_info.json"
+
+
 @dataclass
 class ModelMetadata:
     # The model input names.
@@ -40,6 +48,16 @@ class ModelMetadata:
     identifier: str
 
 
+@dataclass
+class ModelDebugInfo:
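+    # Debug info parsed from the "executorch_debug_handle_mapping.json" file inside the saved .mlpackage, if present.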
+    # Version info.
+    versionInfo: Dict[str, str]
+    # Mapping from debug symbol to operation path.
+    debugSymbolToOperationPath: Dict[str, List[Dict[str, str]]]
+    # Mapping from debug symbol to handles.
+    debugSymbolToHandles: Dict[str, List[int]]
+
+
 @final
 class CoreMLBackend(BackendDetails):
     class MODEL_TYPE(Enum):
@@ -165,53 +183,163 @@ def generate_compile_specs(
         return compile_specs
 
     @staticmethod
-    def model_metadata_from_spec(model_spec: ct.proto.Model_pb2) -> Dict[str, str]:
+    def model_metadata_from_spec(
+        model_spec: ct.proto.Model_pb2, identifier: str
+    ) -> ModelMetadata:
         input_names: List[str] = [input.name for input in model_spec.description.input]
         output_names = [output.name for output in model_spec.description.output]
-        identifier = uuid.uuid4()
 
         return ModelMetadata(
-            inputNames=input_names, outputNames=output_names, identifier=str(identifier)
+            inputNames=input_names, outputNames=output_names, identifier=identifier
+        )
+
+    @staticmethod
+    def get_debug_symbol(operation_path: List[Dict[str, str]]) -> Optional[str]:
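+        # A debug symbol is "<output name>:<operator name>", taken from the last entry of the operation path.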
+        if len(operation_path) == 0:
+            return None
+
+        operator_name: Optional[str] = operation_path[-1].get("Operator", None)
+        output_name: Optional[str] = operation_path[-1].get("Output", None)
+        if output_name is None or operator_name is None:
+            return None
+
+        return output_name + ":" + operator_name
+
+    @staticmethod
+    def get_model_debug_info(model_package_dir: Path) -> Optional[ModelDebugInfo]:
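+        # Parse the delegate mapping file emitted into the .mlpackage, if any, and build the debug-symbol maps.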
+        delegate_info_file = model_package_dir / "executorch_debug_handle_mapping.json"
+
+        if not delegate_info_file.is_file():
+            return None
+
+        delegate_info: Optional[Dict[str, Any]] = None
+
+        try:
+            with open(delegate_info_file) as f:
+                delegate_info = json.load(f)
+        except ValueError:
+            return None
+
+        if delegate_info is None:
+            return None
+
+        debug_handle_to_operation_path_mapping: Optional[Dict[str, Any]] = (
+            delegate_info.get("mapping", None)
+        )
+
+        if debug_handle_to_operation_path_mapping is None:
+            return None
+
+        debug_symbol_to_operation_path: Dict[str, List[Dict[str, str]]] = {}
+        debug_symbol_to_handles: Dict[str, List[int]] = {}
+        for (
+            debug_handle,
+            operation_paths,
+        ) in debug_handle_to_operation_path_mapping.items():
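+            # Keys in the mapping are debug handles serialized as strings; skip entries that do not parse as integers.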
+            debug_handle_value: Optional[int] = None
+            try:
+                debug_handle_value = int(debug_handle)
+            except ValueError:
+                debug_handle_value = None
+
+            if debug_handle_value is None:
+                continue
+
+            for operation_path in operation_paths:
+                debug_symbol: Optional[str] = CoreMLBackend.get_debug_symbol(
+                    operation_path=operation_path
+                )
+
+                if debug_symbol is None:
+                    continue
+
+                debug_handle_values: List[int] = debug_symbol_to_handles.get(
+                    debug_symbol, []
+                )
+                debug_handle_values.append(debug_handle_value)
+                debug_symbol_to_handles[debug_symbol] = debug_handle_values
+
+                debug_symbol_to_operation_path[debug_symbol] = operation_path
+
+        version_info: Dict[str, str] = delegate_info.get("version", {})
+
+        return ModelDebugInfo(
+            versionInfo=version_info,
+            debugSymbolToOperationPath=debug_symbol_to_operation_path,
+            debugSymbolToHandles=debug_symbol_to_handles,
         )
 
     @staticmethod
-    def to_bytes(mlmodel: ct.models.MLModel, model_type: MODEL_TYPE) -> bytes:
-        dir_path: Path = Path("tmp")
+    def save_model_metadata(model_metadata: ModelMetadata, model_dir_path: Path):
+        # Store model metadata.
+        model_metadata_path = Path(model_dir_path) / MODEL_PATHS.METADATA.value
+        model_metadata_json = json.dumps(asdict(model_metadata))
+        with open(model_metadata_path, "w") as outfile:
+            outfile.write(model_metadata_json)
+
+    @staticmethod
+    def save_model_debug_info(model_debug_info: ModelDebugInfo, model_dir_path: Path):
+        # Store model debug info.
+        model_debug_info_path = Path(model_dir_path) / MODEL_PATHS.DEBUG_INFO.value
+        model_debug_info_json = json.dumps(asdict(model_debug_info))
+        with open(model_debug_info_path, "w") as outfile:
+            outfile.write(model_debug_info_json)
+
+    @staticmethod
+    def preprocess_model(
+        mlmodel: ct.models.MLModel, model_type: MODEL_TYPE
+    ) -> PreprocessResult:
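+        # Write the (optionally compiled) model plus its metadata and debug info into a unique temporary directory, then flatten that directory into the processed bytes.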
+        identifier = str(uuid.uuid4())
+        dir_path: Path = Path("tmp") / identifier
         model_dir_path: Path = dir_path / "lowered_module"
         model_spec: ct.proto.Model_pb2 = mlmodel.get_spec()
         model_metadata: ModelMetadata = CoreMLBackend.model_metadata_from_spec(
-            model_spec
+            model_spec=model_spec,
+            identifier=identifier,
         )
-        match model_type:
-            case CoreMLBackend.MODEL_TYPE.MODEL:
-                # Store model.
-                model_path = model_dir_path / "model.mlpackage"
-                mlmodel.save(model_path)
 
+        # Save model.
+        model_path = model_dir_path / MODEL_PATHS.MODEL.value
+        mlmodel.save(model_path)
+        # Extract delegate mapping file.
+        model_debug_info: Optional[ModelDebugInfo] = CoreMLBackend.get_model_debug_info(
+            model_path
+        )
+
+        match model_type:
             case CoreMLBackend.MODEL_TYPE.COMPILED_MODEL:
-                # Store compiled model
-                model_path = model_dir_path / "model.mlmodelc"
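+                # Drop the saved .mlpackage and store the compiled .mlmodelc instead.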
+                shutil.rmtree(str(model_path.resolve()))
+                model_path = model_dir_path / MODEL_PATHS.COMPILED_MODEL.value
                 compiled_model_path = mlmodel.get_compiled_model_path()
-
-                shutil.copytree(
+                shutil.move(
                     compiled_model_path,
                     str(model_path.resolve()),
-                    dirs_exist_ok=True,
                 )
 
-        # Store model metadata.
-        model_metadata_path = Path(model_dir_path) / "metadata.json"
-        model_metadata_json = json.dumps(asdict(model_metadata))
-        with open(model_metadata_path, "w") as outfile:
-            outfile.write(model_metadata_json)
+            case _:
+                pass
 
-        # flatten directory contents and convert it to bytes
-        flattened_bytes = executorchcoreml.flatten_directory_contents(
+        CoreMLBackend.save_model_metadata(
+            model_metadata=model_metadata, model_dir_path=model_dir_path
+        )
+        if model_debug_info is not None:
+            CoreMLBackend.save_model_debug_info(
+                model_debug_info=model_debug_info, model_dir_path=model_dir_path
+            )
+
+        processed_bytes: bytes = executorchcoreml.flatten_directory_contents(
             str(model_dir_path.resolve())
         )
 
-        shutil.rmtree(str(model_dir_path.resolve()))
-        return flattened_bytes
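+        # Report the debug-symbol-to-handle mapping, if available, through the PreprocessResult.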
+        debug_handle_map: Optional[Dict[str, Tuple[int]]] = None
+        if model_debug_info is not None:
+            debug_handle_map = model_debug_info.debugSymbolToHandles
+
+        shutil.rmtree(str(dir_path.resolve()))
+        return PreprocessResult(
+            processed_bytes=processed_bytes,
+            debug_handle_map=debug_handle_map,
+        )
 
     @classmethod
     def preprocess(
@@ -235,25 +363,14 @@ def preprocess(
             CoreMLBackend.min_deployment_target_from_compile_specs(module_compile_specs)
         )
 
-        skip_model_load: bool = False
-        match model_type:
-            case CoreMLBackend.MODEL_TYPE.MODEL:
-                skip_model_load = True
-
-            case CoreMLBackend.MODEL_TYPE.COMPILED_MODEL:
-                skip_model_load = False
-
         mlmodel = ct.convert(
             model=edge_program,
             source="pytorch",
             convert_to="mlprogram",
             pass_pipeline=ct.PassPipeline.DEFAULT,
-            skip_model_load=skip_model_load,
+            skip_model_load=False,
             compute_precision=model_compute_precision,
             minimum_deployment_target=minimum_deployment_target,
         )
 
-        processed_bytes = CoreMLBackend.to_bytes(mlmodel, model_type=model_type)
-        return PreprocessResult(
-            processed_bytes=processed_bytes,
-        )
+        return CoreMLBackend.preprocess_model(mlmodel, model_type=model_type)
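
For reference, a minimal sketch of the debug-symbol convention the new helpers implement. The import path and the extra operation-path entries are assumptions for illustration; only the "Operator" and "Output" keys of the last entry are actually consulted by `get_debug_symbol`.

```python
# Assumed import path for the backend shown in this diff.
from executorch.backends.apple.coreml.compiler import CoreMLBackend

# A hypothetical operation path, shaped like an entry of the "mapping"
# section in "executorch_debug_handle_mapping.json".
operation_path = [
    {"Type": "Program"},
    {"Type": "Function", "Name": "main"},
    {"Operator": "conv", "Output": "aten_convolution_default"},
]

# Prints "aten_convolution_default:conv", i.e. "<output name>:<operator name>";
# the same string keys ModelDebugInfo.debugSymbolToHandles and ends up in
# PreprocessResult.debug_handle_map.
print(CoreMLBackend.get_debug_symbol(operation_path=operation_path))
```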