1717import coremltools as ct
1818import coremltools .optimize as cto
1919from executorch .backends .apple .coreml import executorchcoreml
20+ from numpy import isin
2021from executorch .backends .apple .coreml .logging import get_coreml_log_level
2122from executorch .exir .backend .backend_details import (
2223 BackendDetails ,
@@ -37,6 +38,7 @@ class COMPILE_SPEC_KEYS(Enum):
3738 MIN_DEPLOYMENT_TARGET = "min_deployment_target"
3839 MODEL_COMPUTE_PRECISION = "model_compute_precision"
3940 OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
41+ CT_INPUTS = "ct_inputs"
4042
4143
4244class MODEL_PATHS (Enum ):
@@ -214,6 +216,152 @@ def op_linear_quantizer_config_from_compile_specs(
214216
215217 return None
216218
219+
220+ @staticmethod
221+ def generate_ct_inputs_compile_spec (
222+ ct_inputs : List [ct .TensorType ],
223+ ) -> CompileSpec :
224+ """
225+ Returns the compile spec representing the model inputs
226+ Generally this is not needed, but is used to specify things that cannot be inferred from
227+ the exported program, like enumerated shapes
228+ """
229+
230+ def _is_int_shape (seq ):
231+ return isinstance (seq , (list , tuple )) and all (isinstance (x , int ) for x in seq )
232+
233+ def _serialize_shape (shape ):
234+ # Case 1: None
235+ if shape is None :
236+ return {"kind" : "fixed" , "shape" : None }
237+
238+ # Case 2: Plain list/tuple of ints
239+ if _is_int_shape (shape ):
240+ return {"kind" : "fixed" , "shape" : list (shape )}
241+
242+ # Case 3: EnumeratedShapes (with ct.Shape entries)
243+ if isinstance (shape , ct .EnumeratedShapes ):
244+ shapes = []
245+ for s in shape .shapes :
246+ # ct.Shape(...) -> s.shape should be a tuple of ints
247+ if not _is_int_shape (s .shape ):
248+ raise TypeError ("EnumeratedShapes entries must be tuples/lists of ints" )
249+ shapes .append (list (s .shape ))
250+ default = None
251+ if shape .default is not None :
252+ if not _is_int_shape (shape .default .shape ):
253+ raise TypeError ("EnumeratedShapes.default must be a tuple/list of ints" )
254+ default = list (shape .default .shape )
255+ return {"kind" : "enumerated" , "shapes" : shapes , "default" : default }
256+
257+ # Anything else is out of scope for now
258+ raise TypeError ("Shape must be EnumeratedShapes, a list/tuple of ints, or None" )
259+
260+ def tensor_type_to_dict (t : ct .TensorType ):
261+ assert isinstance (t , ct .TensorType )
262+ for attr in ["name" , "dtype" , "default_value" ]:
263+ assert getattr (t , attr ) is None , f"{ attr } cannot be given a value"
264+ return {
265+ "kind" : "TensorType" ,
266+ "name" : t .name ,
267+ "shape" : _serialize_shape (t .shape ),
268+ "dtype" : t .dtype ,
269+ "default_value" : t .default_value ,
270+ }
271+
272+ str_representation = json .dumps ([tensor_type_to_dict (ct_in ) for ct_in in ct_inputs ])
273+ byte_representation = str_representation .encode ("utf-8" )
274+ return CompileSpec (
275+ COMPILE_SPEC_KEYS .CT_INPUTS .value ,
276+ byte_representation ,
277+ )
278+
279+ @staticmethod
280+ def ct_inputs_from_compile_specs (
281+ compile_specs : List ["CompileSpec" ],
282+ ) -> Optional [List [ct .TensorType ]]:
283+ """
284+ Returns the model's ct.inputs by parsing the list of compile specs.
285+
286+ Expected JSON schema per entry (as produced by generate_ct_inputs_compile_spec):
287+ {
288+ "kind": "TensorType",
289+ "name": "<non-empty string>",
290+ "shape": {
291+ "kind": "fixed", "shape": [int, ...] | null
292+ | "kind": "enumerated", "shapes": [[int, ...], ...], "default": [int, ...] | null
293+ }
294+ }
295+ """
296+ def _is_int_shape (seq ):
297+ return isinstance (seq , (list , tuple )) and all (isinstance (x , int ) for x in seq )
298+
299+ def _parse_shape (shape_json ):
300+ if not isinstance (shape_json , dict ) or "kind" not in shape_json :
301+ raise ValueError ("Invalid shape JSON: missing 'kind'" )
302+
303+ kind = shape_json ["kind" ]
304+
305+ # Case: fixed
306+ if kind == "fixed" :
307+ shp = shape_json .get ("shape" , None )
308+ if shp is None :
309+ return None
310+ if not _is_int_shape (shp ):
311+ raise TypeError ("Fixed shape must be a list/tuple of ints or null" )
312+ return tuple (shp )
313+
314+ # Case: enumerated
315+ if kind == "enumerated" :
316+ shapes = shape_json .get ("shapes" , None )
317+ if not isinstance (shapes , list ) or not shapes :
318+ raise ValueError ("Enumerated shape must have non-empty 'shapes' list" )
319+
320+ parsed_shapes = []
321+ for s in shapes :
322+ if not _is_int_shape (s ):
323+ raise TypeError ("EnumeratedShapes entries must be lists of ints" )
324+ parsed_shapes .append (ct .Shape (tuple (s )))
325+
326+ default = shape_json .get ("default" , None )
327+ default_shape = None
328+ if default is not None :
329+ if not _is_int_shape (default ):
330+ raise TypeError ("EnumeratedShapes.default must be a list of ints" )
331+ default_shape = ct .Shape (tuple (default ))
332+
333+ return ct .EnumeratedShapes (shapes = parsed_shapes , default = default_shape )
334+
335+ raise ValueError (f"Unsupported shape kind: { kind } " )
336+
337+ for compile_spec in compile_specs :
338+ if compile_spec .key == COMPILE_SPEC_KEYS .CT_INPUTS .value :
339+ raw = compile_spec .value .decode ("utf-8" )
340+ payload = json .loads (raw )
341+
342+ if not isinstance (payload , list ):
343+ raise ValueError ("CT_INPUTS payload must be a list" )
344+
345+ ct_inputs : List [ct .TensorType ] = []
346+ for entry in payload :
347+ if not isinstance (entry , dict ) or entry .get ("kind" ) != "TensorType" :
348+ raise ValueError ("Each entry must be a dict with kind == 'TensorType'" )
349+
350+ name = entry .get ("name" , "" )
351+ if not isinstance (name , str ) or not name :
352+ raise ValueError ("TensorType.name must be a non-empty string" )
353+
354+ shape_json = entry .get ("shape" , None )
355+ shape = _parse_shape (shape_json ) if shape_json is not None else None
356+
357+ # Per your current contract, dtype/default_value must be None (and were omitted).
358+ # So we only pass name + shape here.
359+ ct_inputs .append (ct .TensorType (name = name , shape = shape ))
360+
361+ return ct_inputs
362+
363+ return None
364+
217365 @staticmethod
218366 def generate_compile_specs (
219367 compute_unit : ct .ComputeUnit = ct .ComputeUnit .ALL ,
@@ -446,6 +594,9 @@ def preprocess(
446594 op_linear_quantizer_config = (
447595 CoreMLBackend .op_linear_quantizer_config_from_compile_specs (compile_specs )
448596 )
597+ enumerated_shapes = (
598+ CoreMLBackend .enumerate_shapes_from_compile_specs (compile_specs )
599+ )
449600
450601 # Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
451602 # get_compiled_model_path() requires a loaded model.
0 commit comments