Skip to content

Commit 32446f2

Browse files
committed
init
1 parent efe4756 commit 32446f2

File tree

5 files changed

+419
-144
lines changed

5 files changed

+419
-144
lines changed

backends/apple/coreml/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ runtime.python_library(
1818
srcs = glob([
1919
"compiler/*.py",
2020
"logging.py",
21+
"enumerated_shape_utils.py",
2122
]),
2223
visibility = [
2324
"@EXECUTORCH_CLIENTS",

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 61 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
import coremltools as ct
1818
import coremltools.optimize as cto
1919
from executorch.backends.apple.coreml import executorchcoreml
20-
from numpy import isin
20+
from executorch.backends.apple.coreml.enumerated_shape_utils import (
21+
_get_ct_inputs,
22+
_SymbolicShapeToEnumeratedShapeMap,
23+
)
2124
from executorch.backends.apple.coreml.logging import get_coreml_log_level
2225
from executorch.exir.backend.backend_details import (
2326
BackendDetails,
@@ -38,7 +41,7 @@ class COMPILE_SPEC_KEYS(Enum):
3841
MIN_DEPLOYMENT_TARGET = "min_deployment_target"
3942
MODEL_COMPUTE_PRECISION = "model_compute_precision"
4043
OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
41-
CT_INPUTS = "ct_inputs"
44+
ENUMERATED_SHAPES = "enumerated_shapes"
4245

4346

4447
class MODEL_PATHS(Enum):
@@ -145,7 +148,7 @@ def generate_minimum_deployment_target_compile_spec(
145148
@staticmethod
146149
def min_deployment_target_from_compile_specs(
147150
compile_specs: List[CompileSpec],
148-
) -> ct.target:
151+
) -> Optional[ct.target]:
149152
"""
150153
Returns the minimum deployment target by parsing the list of compile specs.
151154
"""
@@ -216,151 +219,53 @@ def op_linear_quantizer_config_from_compile_specs(
216219

217220
return None
218221

219-
220222
@staticmethod
221-
def generate_ct_inputs_compile_spec(
222-
ct_inputs: List[ct.TensorType],
223+
def generate_enumerated_shapes_compile_spec(
224+
ep: ExportedProgram,
225+
enumerated_shapes: Dict[str, List[List[int]]],
223226
) -> CompileSpec:
224227
"""
225-
Returns the compile spec representing the model inputs
226-
Generally this is not needed, but is used to specify things that cannot be inferred from
227-
the exported program, like enumerated shapes
228-
"""
228+
Returns the compile spec representing the model enumerated shapes
229+
enumerated_shapes is a dictionary for each input to its enumerated shapes, e.g.,
229230
230-
def _is_int_shape(seq):
231-
return isinstance(seq, (list, tuple)) and all(isinstance(x, int) for x in seq)
232-
233-
def _serialize_shape(shape):
234-
# Case 1: None
235-
if shape is None:
236-
return {"kind": "fixed", "shape": None}
237-
238-
# Case 2: Plain list/tuple of ints
239-
if _is_int_shape(shape):
240-
return {"kind": "fixed", "shape": list(shape)}
241-
242-
# Case 3: EnumeratedShapes (with ct.Shape entries)
243-
if isinstance(shape, ct.EnumeratedShapes):
244-
shapes = []
245-
for s in shape.shapes:
246-
# ct.Shape(...) -> s.shape should be a tuple of ints
247-
if not _is_int_shape(s.shape):
248-
raise TypeError("EnumeratedShapes entries must be tuples/lists of ints")
249-
shapes.append(list(s.shape))
250-
default = None
251-
if shape.default is not None:
252-
if not _is_int_shape(shape.default.shape):
253-
raise TypeError("EnumeratedShapes.default must be a tuple/list of ints")
254-
default = list(shape.default.shape)
255-
return {"kind": "enumerated", "shapes": shapes, "default": default}
256-
257-
# Anything else is out of scope for now
258-
raise TypeError("Shape must be EnumeratedShapes, a list/tuple of ints, or None")
259-
260-
def tensor_type_to_dict(t: ct.TensorType):
261-
assert isinstance(t, ct.TensorType)
262-
for attr in ["name", "dtype", "default_value"]:
263-
assert getattr(t, attr) is None, f"{attr} cannot be given a value"
264-
return {
265-
"kind": "TensorType",
266-
"name": t.name,
267-
"shape": _serialize_shape(t.shape),
268-
"dtype": t.dtype,
269-
"default_value": t.default_value,
270-
}
231+
enumerated_shapes = {
232+
{"x": [[1, 1, 24], [8, 9, 24]]
233+
{"y": [[1, 6], [30, 6]],
234+
]
235+
236+
means the model can handle x can be shape [1, 1, 24] or [8, 9, 24] and y can be shape [1, 6] or [30, 6].
271237
272-
str_representation = json.dumps([tensor_type_to_dict(ct_in) for ct_in in ct_inputs])
238+
Only multiple inputs can have enumerated shapes if using iOS18 or later.
239+
In this case, each input must have the same number of enumerated shapes, and these shapes are tied together
240+
by their order in the list. For example, the model above can handle x with shape [1, 1, 24] and y with shape [1, 6],
241+
or x with shape [8, 9, 24] and y with shape [30, 6], but not x with shape [1, 1, 24] and y with shape [30, 6].
242+
243+
Passing incorrect shapes at runtime will result in an error.
244+
"""
245+
emap = _SymbolicShapeToEnumeratedShapeMap.from_exported_program(
246+
ep,
247+
enumerated_shapes,
248+
)
249+
str_representation = emap.to_json()
273250
byte_representation = str_representation.encode("utf-8")
274251
return CompileSpec(
275-
COMPILE_SPEC_KEYS.CT_INPUTS.value,
252+
COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value,
276253
byte_representation,
277254
)
278255

279256
@staticmethod
280-
def ct_inputs_from_compile_specs(
281-
compile_specs: List["CompileSpec"],
282-
) -> Optional[List[ct.TensorType]]:
257+
def enumerated_shapes_from_compile_specs(
258+
compile_specs: List[CompileSpec],
259+
) -> cto.coreml.OpLinearQuantizerConfig:
283260
"""
284-
Returns the model's ct.inputs by parsing the list of compile specs.
285-
286-
Expected JSON schema per entry (as produced by generate_ct_inputs_compile_spec):
287-
{
288-
"kind": "TensorType",
289-
"name": "<non-empty string>",
290-
"shape": {
291-
"kind": "fixed", "shape": [int, ...] | null
292-
| "kind": "enumerated", "shapes": [[int, ...], ...], "default": [int, ...] | null
293-
}
294-
}
261+
Returns the model's post conversion quantization by parsing the list of compile specs.
295262
"""
296-
def _is_int_shape(seq):
297-
return isinstance(seq, (list, tuple)) and all(isinstance(x, int) for x in seq)
298-
299-
def _parse_shape(shape_json):
300-
if not isinstance(shape_json, dict) or "kind" not in shape_json:
301-
raise ValueError("Invalid shape JSON: missing 'kind'")
302-
303-
kind = shape_json["kind"]
304-
305-
# Case: fixed
306-
if kind == "fixed":
307-
shp = shape_json.get("shape", None)
308-
if shp is None:
309-
return None
310-
if not _is_int_shape(shp):
311-
raise TypeError("Fixed shape must be a list/tuple of ints or null")
312-
return tuple(shp)
313-
314-
# Case: enumerated
315-
if kind == "enumerated":
316-
shapes = shape_json.get("shapes", None)
317-
if not isinstance(shapes, list) or not shapes:
318-
raise ValueError("Enumerated shape must have non-empty 'shapes' list")
319-
320-
parsed_shapes = []
321-
for s in shapes:
322-
if not _is_int_shape(s):
323-
raise TypeError("EnumeratedShapes entries must be lists of ints")
324-
parsed_shapes.append(ct.Shape(tuple(s)))
325-
326-
default = shape_json.get("default", None)
327-
default_shape = None
328-
if default is not None:
329-
if not _is_int_shape(default):
330-
raise TypeError("EnumeratedShapes.default must be a list of ints")
331-
default_shape = ct.Shape(tuple(default))
332-
333-
return ct.EnumeratedShapes(shapes=parsed_shapes, default=default_shape)
334-
335-
raise ValueError(f"Unsupported shape kind: {kind}")
336-
337263
for compile_spec in compile_specs:
338-
if compile_spec.key == COMPILE_SPEC_KEYS.CT_INPUTS.value:
339-
raw = compile_spec.value.decode("utf-8")
340-
payload = json.loads(raw)
341-
342-
if not isinstance(payload, list):
343-
raise ValueError("CT_INPUTS payload must be a list")
344-
345-
ct_inputs: List[ct.TensorType] = []
346-
for entry in payload:
347-
if not isinstance(entry, dict) or entry.get("kind") != "TensorType":
348-
raise ValueError("Each entry must be a dict with kind == 'TensorType'")
349-
350-
name = entry.get("name", "")
351-
if not isinstance(name, str) or not name:
352-
raise ValueError("TensorType.name must be a non-empty string")
353-
354-
shape_json = entry.get("shape", None)
355-
shape = _parse_shape(shape_json) if shape_json is not None else None
356-
357-
# Per your current contract, dtype/default_value must be None (and were omitted).
358-
# So we only pass name + shape here.
359-
ct_inputs.append(ct.TensorType(name=name, shape=shape))
360-
361-
return ct_inputs
362-
363-
return None
264+
if compile_spec.key == COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value:
265+
emap_json = compile_spec.value.decode("utf-8")
266+
emap = _SymbolicShapeToEnumeratedShapeMap.from_json(emap_json)
267+
return emap
268+
return None
364269

365270
@staticmethod
366271
def generate_compile_specs(
@@ -594,10 +499,29 @@ def preprocess(
594499
op_linear_quantizer_config = (
595500
CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs)
596501
)
597-
enumerated_shapes = (
598-
CoreMLBackend.enumerate_shapes_from_compile_specs(compile_specs)
502+
enumerated_shapes = CoreMLBackend.enumerated_shapes_from_compile_specs(
503+
compile_specs
599504
)
600505

506+
# If using enumerated shapes, we need to pass the inputs to CoreML's convert() function
507+
# explicitly
508+
ct_inputs = None
509+
if enumerated_shapes is not None:
510+
ct_inputs = _get_ct_inputs(edge_program, enumerated_shapes)
511+
512+
# Check there are not multiple enumerated inputs if iOS is below 18
513+
if (minimum_deployment_target is None) or (
514+
minimum_deployment_target < ct.target.iOS18
515+
):
516+
n_enumerated_inputs = 0
517+
for ct_in in ct_inputs:
518+
if isinstance(ct_in.shape, ct.EnumeratedShapes):
519+
n_enumerated_inputs += 1
520+
if n_enumerated_inputs > 1:
521+
raise ValueError(
522+
f"You're program has {n_enumerated_inputs}, but the minimum_deployment_target is set to {minimum_deployment_target}. Multiple enumerated inputs requires iOS18 or later."
523+
)
524+
601525
# Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
602526
# get_compiled_model_path() requires a loaded model.
603527
skip_model_load = model_type != CoreMLBackend.MODEL_TYPE.COMPILED_MODEL
@@ -610,6 +534,7 @@ def preprocess(
610534
compute_precision=model_compute_precision,
611535
minimum_deployment_target=minimum_deployment_target,
612536
compute_units=compute_units,
537+
inputs=ct_inputs,
613538
)
614539

615540
if op_linear_quantizer_config is not None:

0 commit comments

Comments
 (0)