1717import coremltools as ct
1818import coremltools .optimize as cto
1919from executorch .backends .apple .coreml import executorchcoreml
20+ from executorch .backends .apple .coreml .compiler .enumerated_shape_utils import (
21+ _get_ct_inputs ,
22+ _SymbolicShapeToEnumeratedShapeMap ,
23+ )
2024from executorch .backends .apple .coreml .logging import get_coreml_log_level
2125from executorch .exir .backend .backend_details import (
2226 BackendDetails ,
@@ -37,6 +41,7 @@ class COMPILE_SPEC_KEYS(Enum):
3741 MIN_DEPLOYMENT_TARGET = "min_deployment_target"
3842 MODEL_COMPUTE_PRECISION = "model_compute_precision"
3943 OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
44+ ENUMERATED_SHAPES = "enumerated_shapes"
4045
4146
4247class MODEL_PATHS (Enum ):
@@ -143,7 +148,7 @@ def generate_minimum_deployment_target_compile_spec(
143148 @staticmethod
144149 def min_deployment_target_from_compile_specs (
145150 compile_specs : List [CompileSpec ],
146- ) -> ct .target :
151+ ) -> Optional [ ct .target ] :
147152 """
148153 Returns the minimum deployment target by parsing the list of compile specs.
149154 """
@@ -214,6 +219,54 @@ def op_linear_quantizer_config_from_compile_specs(
214219
215220 return None
216221
222+ @staticmethod
223+ def generate_enumerated_shapes_compile_spec (
224+ ep : ExportedProgram ,
225+ enumerated_shapes : Dict [str , List [List [int ]]],
226+ ) -> CompileSpec :
227+ """
228+ Returns the compile spec representing the model enumerated shapes
229+ enumerated_shapes is a dictionary for each input to its enumerated shapes, e.g.,
230+
231+ enumerated_shapes = {
232+ {"x": [[1, 1, 24], [8, 9, 24]]
233+ {"y": [[1, 6], [30, 6]],
234+ ]
235+
236+ means the model can handle x can be shape [1, 1, 24] or [8, 9, 24] and y can be shape [1, 6] or [30, 6].
237+
238+ Only multiple inputs can have enumerated shapes if using iOS18 or later.
239+ In this case, each input must have the same number of enumerated shapes, and these shapes are tied together
240+ by their order in the list. For example, the model above can handle x with shape [1, 1, 24] and y with shape [1, 6],
241+ or x with shape [8, 9, 24] and y with shape [30, 6], but not x with shape [1, 1, 24] and y with shape [30, 6].
242+
243+ Passing incorrect shapes at runtime will result in an error.
244+ """
245+ emap = _SymbolicShapeToEnumeratedShapeMap .from_exported_program (
246+ ep ,
247+ enumerated_shapes ,
248+ )
249+ str_representation = emap .to_json ()
250+ byte_representation = str_representation .encode ("utf-8" )
251+ return CompileSpec (
252+ COMPILE_SPEC_KEYS .ENUMERATED_SHAPES .value ,
253+ byte_representation ,
254+ )
255+
256+ @staticmethod
257+ def enumerated_shapes_from_compile_specs (
258+ compile_specs : List [CompileSpec ],
259+ ) -> cto .coreml .OpLinearQuantizerConfig :
260+ """
261+ Returns the model's post conversion quantization by parsing the list of compile specs.
262+ """
263+ for compile_spec in compile_specs :
264+ if compile_spec .key == COMPILE_SPEC_KEYS .ENUMERATED_SHAPES .value :
265+ emap_json = compile_spec .value .decode ("utf-8" )
266+ emap = _SymbolicShapeToEnumeratedShapeMap .from_json (emap_json )
267+ return emap
268+ return None
269+
217270 @staticmethod
218271 def generate_compile_specs (
219272 compute_unit : ct .ComputeUnit = ct .ComputeUnit .ALL ,
@@ -446,6 +499,28 @@ def preprocess(
446499 op_linear_quantizer_config = (
447500 CoreMLBackend .op_linear_quantizer_config_from_compile_specs (compile_specs )
448501 )
502+ enumerated_shapes = CoreMLBackend .enumerated_shapes_from_compile_specs (
503+ compile_specs
504+ )
505+
506+ # If using enumerated shapes, we need to pass the inputs to CoreML's convert() function
507+ # explicitly
508+ ct_inputs = None
509+ if enumerated_shapes is not None :
510+ ct_inputs = _get_ct_inputs (edge_program , enumerated_shapes )
511+
512+ # Check there are not multiple enumerated inputs if iOS is below 18
513+ if (minimum_deployment_target is None ) or (
514+ minimum_deployment_target < ct .target .iOS18
515+ ):
516+ n_enumerated_inputs = 0
517+ for ct_in in ct_inputs :
518+ if isinstance (ct_in .shape , ct .EnumeratedShapes ):
519+ n_enumerated_inputs += 1
520+ if n_enumerated_inputs > 1 :
521+ raise ValueError (
522+ f"You're program has { n_enumerated_inputs } , but the minimum_deployment_target is set to { minimum_deployment_target } . Multiple enumerated inputs requires iOS18 or later."
523+ )
449524
450525 # Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
451526 # get_compiled_model_path() requires a loaded model.
@@ -459,6 +534,7 @@ def preprocess(
459534 compute_precision = model_compute_precision ,
460535 minimum_deployment_target = minimum_deployment_target ,
461536 compute_units = compute_units ,
537+ inputs = ct_inputs ,
462538 )
463539
464540 if op_linear_quantizer_config is not None :
0 commit comments