1717import coremltools as ct
1818import coremltools .optimize as cto
1919from executorch .backends .apple .coreml import executorchcoreml
20- from numpy import isin
20+ from executorch .backends .apple .coreml .enumerated_shape_utils import (
21+ _get_ct_inputs ,
22+ _SymbolicShapeToEnumeratedShapeMap ,
23+ )
2124from executorch .backends .apple .coreml .logging import get_coreml_log_level
2225from executorch .exir .backend .backend_details import (
2326 BackendDetails ,
@@ -38,7 +41,7 @@ class COMPILE_SPEC_KEYS(Enum):
3841 MIN_DEPLOYMENT_TARGET = "min_deployment_target"
3942 MODEL_COMPUTE_PRECISION = "model_compute_precision"
4043 OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
41- CT_INPUTS = "ct_inputs "
44+ ENUMERATED_SHAPES = "enumerated_shapes "
4245
4346
4447class MODEL_PATHS (Enum ):
@@ -145,7 +148,7 @@ def generate_minimum_deployment_target_compile_spec(
145148 @staticmethod
146149 def min_deployment_target_from_compile_specs (
147150 compile_specs : List [CompileSpec ],
148- ) -> ct .target :
151+ ) -> Optional [ ct .target ] :
149152 """
150153 Returns the minimum deployment target by parsing the list of compile specs.
151154 """
@@ -216,151 +219,53 @@ def op_linear_quantizer_config_from_compile_specs(
216219
217220 return None
218221
219-
220222 @staticmethod
221- def generate_ct_inputs_compile_spec (
222- ct_inputs : List [ct .TensorType ],
223+ def generate_enumerated_shapes_compile_spec (
224+ ep : ExportedProgram ,
225+ enumerated_shapes : Dict [str , List [List [int ]]],
223226 ) -> CompileSpec :
224227 """
225- Returns the compile spec representing the model inputs
226- Generally this is not needed, but is used to specify things that cannot be inferred from
227- the exported program, like enumerated shapes
228- """
228+ Returns the compile spec representing the model enumerated shapes
229+ enumerated_shapes is a dictionary mapping each input name to its list of enumerated shapes, e.g.,
229230
230- def _is_int_shape (seq ):
231- return isinstance (seq , (list , tuple )) and all (isinstance (x , int ) for x in seq )
232-
233- def _serialize_shape (shape ):
234- # Case 1: None
235- if shape is None :
236- return {"kind" : "fixed" , "shape" : None }
237-
238- # Case 2: Plain list/tuple of ints
239- if _is_int_shape (shape ):
240- return {"kind" : "fixed" , "shape" : list (shape )}
241-
242- # Case 3: EnumeratedShapes (with ct.Shape entries)
243- if isinstance (shape , ct .EnumeratedShapes ):
244- shapes = []
245- for s in shape .shapes :
246- # ct.Shape(...) -> s.shape should be a tuple of ints
247- if not _is_int_shape (s .shape ):
248- raise TypeError ("EnumeratedShapes entries must be tuples/lists of ints" )
249- shapes .append (list (s .shape ))
250- default = None
251- if shape .default is not None :
252- if not _is_int_shape (shape .default .shape ):
253- raise TypeError ("EnumeratedShapes.default must be a tuple/list of ints" )
254- default = list (shape .default .shape )
255- return {"kind" : "enumerated" , "shapes" : shapes , "default" : default }
256-
257- # Anything else is out of scope for now
258- raise TypeError ("Shape must be EnumeratedShapes, a list/tuple of ints, or None" )
259-
260- def tensor_type_to_dict (t : ct .TensorType ):
261- assert isinstance (t , ct .TensorType )
262- for attr in ["name" , "dtype" , "default_value" ]:
263- assert getattr (t , attr ) is None , f"{ attr } cannot be given a value"
264- return {
265- "kind" : "TensorType" ,
266- "name" : t .name ,
267- "shape" : _serialize_shape (t .shape ),
268- "dtype" : t .dtype ,
269- "default_value" : t .default_value ,
270- }
231+ enumerated_shapes = {
232+ "x": [[1, 1, 24], [8, 9, 24]],
233+ "y": [[1, 6], [30, 6]],
234+ }
235+
236+ means x can have shape [1, 1, 24] or [8, 9, 24], and y can have shape [1, 6] or [30, 6].
271237
272- str_representation = json .dumps ([tensor_type_to_dict (ct_in ) for ct_in in ct_inputs ])
238+ Multiple inputs can have enumerated shapes only if using iOS18 or later.
239+ In this case, each input must have the same number of enumerated shapes, and these shapes are tied together
240+ by their order in the list. For example, the model above can handle x with shape [1, 1, 24] and y with shape [1, 6],
241+ or x with shape [8, 9, 24] and y with shape [30, 6], but not x with shape [1, 1, 24] and y with shape [30, 6].
242+
243+ Passing incorrect shapes at runtime will result in an error.
244+ """
245+ emap = _SymbolicShapeToEnumeratedShapeMap .from_exported_program (
246+ ep ,
247+ enumerated_shapes ,
248+ )
249+ str_representation = emap .to_json ()
273250 byte_representation = str_representation .encode ("utf-8" )
274251 return CompileSpec (
275- COMPILE_SPEC_KEYS .CT_INPUTS .value ,
252+ COMPILE_SPEC_KEYS .ENUMERATED_SHAPES .value ,
276253 byte_representation ,
277254 )
278255
279256 @staticmethod
280- def ct_inputs_from_compile_specs (
281- compile_specs : List [" CompileSpec" ],
282- ) -> Optional [ List [ ct . TensorType ]] :
257+ def enumerated_shapes_from_compile_specs (
258+ compile_specs : List [CompileSpec ],
259+ ) -> cto . coreml . OpLinearQuantizerConfig :
283260 """
284- Returns the model's ct.inputs by parsing the list of compile specs.
285-
286- Expected JSON schema per entry (as produced by generate_ct_inputs_compile_spec):
287- {
288- "kind": "TensorType",
289- "name": "<non-empty string>",
290- "shape": {
291- "kind": "fixed", "shape": [int, ...] | null
292- | "kind": "enumerated", "shapes": [[int, ...], ...], "default": [int, ...] | null
293- }
294- }
261+ Returns the model's enumerated shapes map by parsing the list of compile specs, or None if no enumerated shapes compile spec is present.
295262 """
296- def _is_int_shape (seq ):
297- return isinstance (seq , (list , tuple )) and all (isinstance (x , int ) for x in seq )
298-
299- def _parse_shape (shape_json ):
300- if not isinstance (shape_json , dict ) or "kind" not in shape_json :
301- raise ValueError ("Invalid shape JSON: missing 'kind'" )
302-
303- kind = shape_json ["kind" ]
304-
305- # Case: fixed
306- if kind == "fixed" :
307- shp = shape_json .get ("shape" , None )
308- if shp is None :
309- return None
310- if not _is_int_shape (shp ):
311- raise TypeError ("Fixed shape must be a list/tuple of ints or null" )
312- return tuple (shp )
313-
314- # Case: enumerated
315- if kind == "enumerated" :
316- shapes = shape_json .get ("shapes" , None )
317- if not isinstance (shapes , list ) or not shapes :
318- raise ValueError ("Enumerated shape must have non-empty 'shapes' list" )
319-
320- parsed_shapes = []
321- for s in shapes :
322- if not _is_int_shape (s ):
323- raise TypeError ("EnumeratedShapes entries must be lists of ints" )
324- parsed_shapes .append (ct .Shape (tuple (s )))
325-
326- default = shape_json .get ("default" , None )
327- default_shape = None
328- if default is not None :
329- if not _is_int_shape (default ):
330- raise TypeError ("EnumeratedShapes.default must be a list of ints" )
331- default_shape = ct .Shape (tuple (default ))
332-
333- return ct .EnumeratedShapes (shapes = parsed_shapes , default = default_shape )
334-
335- raise ValueError (f"Unsupported shape kind: { kind } " )
336-
337263 for compile_spec in compile_specs :
338- if compile_spec .key == COMPILE_SPEC_KEYS .CT_INPUTS .value :
339- raw = compile_spec .value .decode ("utf-8" )
340- payload = json .loads (raw )
341-
342- if not isinstance (payload , list ):
343- raise ValueError ("CT_INPUTS payload must be a list" )
344-
345- ct_inputs : List [ct .TensorType ] = []
346- for entry in payload :
347- if not isinstance (entry , dict ) or entry .get ("kind" ) != "TensorType" :
348- raise ValueError ("Each entry must be a dict with kind == 'TensorType'" )
349-
350- name = entry .get ("name" , "" )
351- if not isinstance (name , str ) or not name :
352- raise ValueError ("TensorType.name must be a non-empty string" )
353-
354- shape_json = entry .get ("shape" , None )
355- shape = _parse_shape (shape_json ) if shape_json is not None else None
356-
357- # Per your current contract, dtype/default_value must be None (and were omitted).
358- # So we only pass name + shape here.
359- ct_inputs .append (ct .TensorType (name = name , shape = shape ))
360-
361- return ct_inputs
362-
363- return None
264+ if compile_spec .key == COMPILE_SPEC_KEYS .ENUMERATED_SHAPES .value :
265+ emap_json = compile_spec .value .decode ("utf-8" )
266+ emap = _SymbolicShapeToEnumeratedShapeMap .from_json (emap_json )
267+ return emap
268+ return None
364269
365270 @staticmethod
366271 def generate_compile_specs (
@@ -594,10 +499,29 @@ def preprocess(
594499 op_linear_quantizer_config = (
595500 CoreMLBackend .op_linear_quantizer_config_from_compile_specs (compile_specs )
596501 )
597- enumerated_shapes = (
598- CoreMLBackend . enumerate_shapes_from_compile_specs ( compile_specs )
502+ enumerated_shapes = CoreMLBackend . enumerated_shapes_from_compile_specs (
503+ compile_specs
599504 )
600505
506+ # If using enumerated shapes, we need to pass the inputs to CoreML's convert() function
507+ # explicitly
508+ ct_inputs = None
509+ if enumerated_shapes is not None :
510+ ct_inputs = _get_ct_inputs (edge_program , enumerated_shapes )
511+
512+ # Check there are not multiple enumerated inputs if iOS is below 18
513+ if (minimum_deployment_target is None ) or (
514+ minimum_deployment_target < ct .target .iOS18
515+ ):
516+ n_enumerated_inputs = 0
517+ for ct_in in ct_inputs :
518+ if isinstance (ct_in .shape , ct .EnumeratedShapes ):
519+ n_enumerated_inputs += 1
520+ if n_enumerated_inputs > 1 :
521+ raise ValueError (
522+ f"You're program has { n_enumerated_inputs } , but the minimum_deployment_target is set to { minimum_deployment_target } . Multiple enumerated inputs requires iOS18 or later."
523+ )
524+
601525 # Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
602526 # get_compiled_model_path() requires a loaded model.
603527 skip_model_load = model_type != CoreMLBackend .MODEL_TYPE .COMPILED_MODEL
@@ -610,6 +534,7 @@ def preprocess(
610534 compute_precision = model_compute_precision ,
611535 minimum_deployment_target = minimum_deployment_target ,
612536 compute_units = compute_units ,
537+ inputs = ct_inputs ,
613538 )
614539
615540 if op_linear_quantizer_config is not None :
0 commit comments