1 change: 1 addition & 0 deletions backends/apple/coreml/TARGETS
@@ -18,6 +18,7 @@ runtime.python_library(
srcs = glob([
"compiler/*.py",
"logging.py",
"enumerated_shape_utils.py",
]),
visibility = [
"@EXECUTORCH_CLIENTS",
78 changes: 77 additions & 1 deletion backends/apple/coreml/compiler/coreml_preprocess.py
@@ -17,6 +17,10 @@
import coremltools as ct
import coremltools.optimize as cto
from executorch.backends.apple.coreml import executorchcoreml
from executorch.backends.apple.coreml.enumerated_shape_utils import (
_get_ct_inputs,
_SymbolicShapeToEnumeratedShapeMap,
)
from executorch.backends.apple.coreml.logging import get_coreml_log_level
from executorch.exir.backend.backend_details import (
BackendDetails,
@@ -37,6 +41,7 @@ class COMPILE_SPEC_KEYS(Enum):
MIN_DEPLOYMENT_TARGET = "min_deployment_target"
MODEL_COMPUTE_PRECISION = "model_compute_precision"
OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
ENUMERATED_SHAPES = "enumerated_shapes"


class MODEL_PATHS(Enum):
@@ -143,7 +148,7 @@ def generate_minimum_deployment_target_compile_spec(
@staticmethod
def min_deployment_target_from_compile_specs(
compile_specs: List[CompileSpec],
) -> ct.target:
) -> Optional[ct.target]:
"""
Returns the minimum deployment target by parsing the list of compile specs.
"""
@@ -214,6 +219,54 @@ def op_linear_quantizer_config_from_compile_specs(

return None

@staticmethod
def generate_enumerated_shapes_compile_spec(
ep: ExportedProgram,
enumerated_shapes: Dict[str, List[List[int]]],
) -> CompileSpec:
"""
Returns the compile spec representing the model's enumerated shapes.
enumerated_shapes is a dictionary mapping each input name to its enumerated shapes, e.g.,

enumerated_shapes = {
"x": [[1, 1, 24], [8, 9, 24]],
"y": [[1, 6], [30, 6]],
}

means x can have shape [1, 1, 24] or [8, 9, 24], and y can have shape [1, 6] or [30, 6].

Multiple inputs can have enumerated shapes only if using iOS18 or later.
In this case, each input must have the same number of enumerated shapes, and these shapes are tied together
by their order in the list. For example, the model above can handle x with shape [1, 1, 24] and y with shape [1, 6],
or x with shape [8, 9, 24] and y with shape [30, 6], but not x with shape [1, 1, 24] and y with shape [30, 6].

Passing incorrect shapes at runtime will result in an error.
"""
emap = _SymbolicShapeToEnumeratedShapeMap.from_exported_program(
ep,
enumerated_shapes,
)
str_representation = emap.to_json()
byte_representation = str_representation.encode("utf-8")
return CompileSpec(
COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value,
byte_representation,
)

@staticmethod
def enumerated_shapes_from_compile_specs(
compile_specs: List[CompileSpec],
) -> Optional[_SymbolicShapeToEnumeratedShapeMap]:
"""
Returns the model's enumerated shapes by parsing the list of compile specs.
"""
for compile_spec in compile_specs:
if compile_spec.key == COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value:
emap_json = compile_spec.value.decode("utf-8")
emap = _SymbolicShapeToEnumeratedShapeMap.from_json(emap_json)
return emap
return None
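
A minimal sketch of the round trip through these two helpers; `ep` is assumed to be an ExportedProgram whose input "x" was exported with dynamic dimensions, and is not part of this diff:

spec = CoreMLBackend.generate_enumerated_shapes_compile_spec(
    ep, {"x": [[1, 1, 24], [8, 9, 24]]}
)
emap = CoreMLBackend.enumerated_shapes_from_compile_specs([spec])
assert emap is not None  # the map survives the JSON byte-payload round trip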

@staticmethod
def generate_compile_specs(
compute_unit: ct.ComputeUnit = ct.ComputeUnit.ALL,
@@ -446,6 +499,28 @@ def preprocess(
op_linear_quantizer_config = (
CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs)
)
enumerated_shapes = CoreMLBackend.enumerated_shapes_from_compile_specs(
compile_specs
)

# If using enumerated shapes, we need to pass the inputs to CoreML's convert() function
# explicitly
ct_inputs = None
if enumerated_shapes is not None:
ct_inputs = _get_ct_inputs(edge_program, enumerated_shapes)

# Check there are not multiple enumerated inputs if the deployment target is below iOS18
if (minimum_deployment_target is None) or (
minimum_deployment_target < ct.target.iOS18
):
n_enumerated_inputs = 0
for ct_in in ct_inputs:
if isinstance(ct_in.shape, ct.EnumeratedShapes):
n_enumerated_inputs += 1
if n_enumerated_inputs > 1:
raise ValueError(
f"You're program has {n_enumerated_inputs}, but the minimum_deployment_target is set to {minimum_deployment_target}. Multiple enumerated inputs requires iOS18 or later."
)

# Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
# get_compiled_model_path() requires a loaded model.
@@ -459,6 +534,7 @@ def preprocess(
compute_precision=model_compute_precision,
minimum_deployment_target=minimum_deployment_target,
compute_units=compute_units,
inputs=ct_inputs,
)

if op_linear_quantizer_config is not None:
233 changes: 233 additions & 0 deletions backends/apple/coreml/enumerated_shape_utils.py
@@ -0,0 +1,233 @@
import json
from dataclasses import asdict, dataclass
from typing import Optional, Tuple

import coremltools as ct
import torch
from coremltools.converters.mil.frontend.torch.utils import TORCH_DTYPE_TO_MIL_DTYPE

_IGNORE_RANGE_CONSTRAINTS: bool = True


@dataclass(frozen=True, slots=True)
class _SymInt:
key_name: str
low: Optional[int]
high: Optional[int]

@classmethod
def from_symint_and_range_constraints(cls, s: torch.SymInt, range_constraints=None):
# Canonicalize: "Sym(s0)" -> "s0", or leave "s0" as is
def _symkey(sym: torch.SymInt) -> str:
s = str(sym)
return s[4:-1] if s.startswith("Sym(") and s.endswith(")") else s

# Convert symint to int. Infinity is converted to None
def _as_int_or_none(b):
if b is None:
return None
s = str(b)
if s in {"int_oo", "-int_oo", "oo", "-oo", "Infinity", "-Infinity"}:
return None
return int(s)

# Get low/high from range_constraints if provided
low, high = None, None
if range_constraints is not None:
for k, v in range_constraints.items():
if _symkey(k) == _symkey(s):
low = _as_int_or_none(getattr(v, "lower", getattr(v, "min", None)))
high = _as_int_or_none(getattr(v, "upper", getattr(v, "max", None)))
return _SymInt(_symkey(s), low, high)


@dataclass(frozen=True, slots=True)
class _SymbolicShape:
shape: Tuple[int | _SymInt, ...]

@classmethod
def from_shape_and_range_constraints(cls, shape, range_constraints=None):
out_shape = []
for s in shape:
if isinstance(s, int):
assert s >= 0
out_shape.append(s)
elif isinstance(s, torch.SymInt):
out_shape.append(
_SymInt.from_symint_and_range_constraints(s, range_constraints)
)
else:
raise ValueError(f"Unexpected type found in shape: {type(s)}")
out_shape = tuple(out_shape)
return _SymbolicShape(out_shape)

def is_static_shape(self):
for s in self.shape:
if isinstance(s, _SymInt):
return False
return True

def __len__(self):
return len(self.shape)

def __getitem__(self, key):
return self.shape[key]

def to_dict(self):
return asdict(self)

@classmethod
def from_dict(cls, d):
assert len(d) == 1 and "shape" in d
shape = []
for s in d["shape"]:
if isinstance(s, int):
shape.append(s)
elif isinstance(s, dict):
assert len(s) == 3 and "key_name" in s and "low" in s and "high" in s
shape.append(_SymInt(**s))
else:
raise ValueError(f"Unexpected type found in shape: {type(s)}")
shape = tuple(shape)
return _SymbolicShape(shape)
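
A small sketch of the dict round trip, using a hypothetical symbolic first dimension:

sym = _SymInt(key_name="s0", low=None, high=None)
s = _SymbolicShape((sym, 24))
assert _SymbolicShape.from_dict(s.to_dict()) == s  # frozen dataclasses compare by value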


def _iterate_over_fake_user_inputs(ep):
user_inputs = ep.graph_signature.user_inputs
for node in ep.graph.nodes:
if node.op == "placeholder" and node.name in user_inputs:
yield (node.name, node.meta["val"])


def _create_enumeration_map(ep, enumerated_shapes, *, ignore_range_constraints=False):
# Each input should have the same number of enumerations
assert (
len({len(v) for v in enumerated_shapes.values()}) == 1
), "Each input with enumerated shapes must have the same number of enumerated shapes"
symbolic_shape_to_enumerations = {}
for name, fake_input in _iterate_over_fake_user_inputs(ep):
shape = fake_input.shape
serialized_shape = _SymbolicShape.from_shape_and_range_constraints(
shape, ep.range_constraints if not ignore_range_constraints else None
)
if serialized_shape.is_static_shape():
continue
# Shape is dynamic
if name not in enumerated_shapes:
raise ValueError(
f"The input {name} has a symbolic shape, but you did not provide an enumeration for it"
)
# Validate
for eshape in enumerated_shapes[name]:
assert len(serialized_shape) == len(
eshape
), f"In {name}, the rank of the enumeration is {len(eshape)}, but the symbolic shape has rank {len(serialized_shape)}"
for i in range(len(eshape)):
assert isinstance(
eshape[i], int
), f"Enumerated shapes must be ints, but got {type(eshape[i])}."
assert eshape[i] >= 1, "Each enumerated shape dimension must be >= 1"
if isinstance(serialized_shape[i], int):
assert (
serialized_shape[i] == eshape[i]
), f"In {name}, the shape enumeration {eshape} does not match {shape} on the non-symbolic value at index {i}"
else:
# Check eshape is within bound
if serialized_shape[i].low is not None:
assert (
eshape[i] >= serialized_shape[i].low
), f"In {name}, the shape enumeration {eshape} violates the lower range-constraint on the symbolic shape {shape} at index {i}"
if serialized_shape[i].high is not None:
assert (
eshape[i] <= serialized_shape[i].high
), f"In {name}, the shape enumeration {eshape} violates the upper range-constraint on the symbolic shape {shape} at index {i}"
if serialized_shape in symbolic_shape_to_enumerations:
enumerations, names = symbolic_shape_to_enumerations[serialized_shape]
assert (
enumerations == enumerated_shapes[name]
), f"The symbolic shape {shape}, has multiple enumerations defined. A new enumeration is defined for input {name}, but the existing inputs {names} have a different one defined. If these inputs have different enumerations, they should be exported with different symbolic shapes."
names.append(name)
symbolic_shape_to_enumerations[serialized_shape] = (enumerations, names)
else:
symbolic_shape_to_enumerations[serialized_shape] = (
enumerated_shapes[name],
[name],
)
return symbolic_shape_to_enumerations
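
For the x/y example in the compile-spec docstring, the returned mapping would look roughly like the sketch below, assuming x was exported with symbolic shape (s0, s1, 24) and y with (s2, 6):

{
    _SymbolicShape((_SymInt("s0", None, None), _SymInt("s1", None, None), 24)):
        ([[1, 1, 24], [8, 9, 24]], ["x"]),
    _SymbolicShape((_SymInt("s2", None, None), 6)):
        ([[1, 6], [30, 6]], ["y"]),
}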


class _SymbolicShapeToEnumeratedShapeMap:
def __init__(self, emap):
self.emap = emap

def to_json(self):
json_list = []
for k in self.emap:
json_list.append((k.to_dict(), self.emap[k]))
return json.dumps(json_list)

@classmethod
def from_json(cls, s):
emap = {}
json_list = json.loads(s)
for k, v in json_list:
k = _SymbolicShape.from_dict(k)
emap[k] = tuple(v)
return cls(emap)
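
A quick sketch of the JSON round trip, with a hypothetical single-entry map:

key = _SymbolicShape((_SymInt("s0", None, None), 6))
m = _SymbolicShapeToEnumeratedShapeMap({key: ([[1, 6], [30, 6]], ["y"])})
m2 = _SymbolicShapeToEnumeratedShapeMap.from_json(m.to_json())
assert m2[key] == [[1, 6], [30, 6]]  # __getitem__ returns only the enumerations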

@classmethod
def from_exported_program(
cls,
ep,
enumerated_shapes,
*,
ignore_range_constraints=_IGNORE_RANGE_CONSTRAINTS,
):
emap = _create_enumeration_map(
ep, enumerated_shapes, ignore_range_constraints=ignore_range_constraints
)
return cls(emap)

def __getitem__(self, key: _SymbolicShape):
return self.emap[key][0]

def __contains__(self, key):
return key in self.emap

def __repr__(self):
return f"_SymbolicShapeToEnumeratedShapeMap(emap={self.emap})"


def _get_ct_inputs(ep, emap: _SymbolicShapeToEnumeratedShapeMap):
ct_inputs = []
for name, fake_input in _iterate_over_fake_user_inputs(ep):

# CoreML can do funny conversions in ct.convert (e.g., int64 -> int32, int16 -> int32), so here
# we restrict users to use dtypes we know are supported
_ENUMERATED_SHAPE_INPUT_DTYPES = [torch.float16, torch.float32, torch.int32]
for dtype in _ENUMERATED_SHAPE_INPUT_DTYPES:
assert dtype in TORCH_DTYPE_TO_MIL_DTYPE
assert (
fake_input.dtype in _ENUMERATED_SHAPE_INPUT_DTYPES
), f"When using enumerated shapes, all inputs must have one of the following dtyeps {_ENUMERATED_SHAPE_INPUT_DTYPES}, but {name} has dtype {fake_input.dtype}"

ct_dtype = TORCH_DTYPE_TO_MIL_DTYPE[fake_input.dtype]
shape = fake_input.shape
serializable_shape = _SymbolicShape.from_shape_and_range_constraints(
shape, ep.range_constraints if not _IGNORE_RANGE_CONSTRAINTS else None
)
if serializable_shape.is_static_shape():
ct_inputs.append(
ct.TensorType(name=name, shape=serializable_shape.shape, dtype=ct_dtype)
)
continue
# Dynamic shape
assert (
serializable_shape in emap
), f"The shape of input {name} ({serializable_shape}) is not in the _SymbolicShapeToEnumeratedShapeMap={emap}"
enumerations = emap[serializable_shape]
ct_enumerated_shape = ct.EnumeratedShapes(shapes=enumerations)
ct_inputs.append(
ct.TensorType(name=name, shape=ct_enumerated_shape, dtype=ct_dtype)
)
return ct_inputs
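
A minimal end-to-end sketch of driving these utilities from torch.export; the toy module and dimension names below are assumptions, not part of this diff:

import torch
from torch.export import Dim, export


class _Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1.0


# The example input avoids size-1 dims so export does not specialize them away
ep = export(
    _Toy(),
    (torch.randn(8, 9, 24),),
    dynamic_shapes={"x": {0: Dim("s0", max=32), 1: Dim("s1", max=32)}},
)
emap = _SymbolicShapeToEnumeratedShapeMap.from_exported_program(
    ep, {"x": [[1, 1, 24], [8, 9, 24]]}
)
ct_inputs = _get_ct_inputs(ep, emap)  # one ct.TensorType wrapping ct.EnumeratedShapes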