diff --git a/backends/apple/coreml/compiler/coreml_preprocess.py b/backends/apple/coreml/compiler/coreml_preprocess.py
index edf7aa97241..d1614f30451 100644
--- a/backends/apple/coreml/compiler/coreml_preprocess.py
+++ b/backends/apple/coreml/compiler/coreml_preprocess.py
@@ -17,6 +17,10 @@
 import coremltools as ct
 import coremltools.optimize as cto
 from executorch.backends.apple.coreml import executorchcoreml
+from executorch.backends.apple.coreml.compiler.enumerated_shape_utils import (
+    _get_ct_inputs,
+    _SymbolicShapeToEnumeratedShapeMap,
+)
 from executorch.backends.apple.coreml.logging import get_coreml_log_level
 from executorch.exir.backend.backend_details import (
     BackendDetails,
@@ -37,6 +41,7 @@ class COMPILE_SPEC_KEYS(Enum):
     MIN_DEPLOYMENT_TARGET = "min_deployment_target"
     MODEL_COMPUTE_PRECISION = "model_compute_precision"
     OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
+    ENUMERATED_SHAPES = "enumerated_shapes"
 
 
 class MODEL_PATHS(Enum):
@@ -143,7 +148,7 @@ def generate_minimum_deployment_target_compile_spec(
     @staticmethod
     def min_deployment_target_from_compile_specs(
         compile_specs: List[CompileSpec],
-    ) -> ct.target:
+    ) -> Optional[ct.target]:
         """
         Returns the minimum deployment target by parsing the list of compile specs.
         """
@@ -214,6 +219,54 @@ def op_linear_quantizer_config_from_compile_specs(
 
         return None
 
+    @staticmethod
+    def generate_enumerated_shapes_compile_spec(
+        ep: ExportedProgram,
+        enumerated_shapes: Dict[str, List[List[int]]],
+    ) -> CompileSpec:
+        """
+        Returns the compile spec representing the model's enumerated shapes.
+        enumerated_shapes is a dictionary mapping each input name to its enumerated shapes, e.g.,
+
+        enumerated_shapes = {
+            "x": [[1, 1, 24], [8, 9, 24]],
+            "y": [[1, 6], [30, 6]],
+        }
+
+        means the model can handle x with shape [1, 1, 24] or [8, 9, 24] and y with shape [1, 6] or [30, 6].
+
+        Multiple inputs can have enumerated shapes only when using iOS18 or later.
+        In this case, each input must have the same number of enumerated shapes, and these shapes are tied together
+        by their order in the list. For example, the model above can handle x with shape [1, 1, 24] and y with shape [1, 6],
+        or x with shape [8, 9, 24] and y with shape [30, 6], but not x with shape [1, 1, 24] and y with shape [30, 6].
+
+        Passing incorrect shapes at runtime will result in an error.
+        """
+        emap = _SymbolicShapeToEnumeratedShapeMap.from_exported_program(
+            ep,
+            enumerated_shapes,
+        )
+        str_representation = emap.to_json()
+        byte_representation = str_representation.encode("utf-8")
+        return CompileSpec(
+            COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value,
+            byte_representation,
+        )
+
+    @staticmethod
+    def enumerated_shapes_from_compile_specs(
+        compile_specs: List[CompileSpec],
+    ) -> Optional[_SymbolicShapeToEnumeratedShapeMap]:
+        """
+        Returns the model's enumerated shapes map by parsing the list of compile specs.
+        """
+        for compile_spec in compile_specs:
+            if compile_spec.key == COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value:
+                emap_json = compile_spec.value.decode("utf-8")
+                emap = _SymbolicShapeToEnumeratedShapeMap.from_json(emap_json)
+                return emap
+        return None
+
     @staticmethod
     def generate_compile_specs(
         compute_unit: ct.ComputeUnit = ct.ComputeUnit.ALL,
@@ -446,6 +499,28 @@ def preprocess(
         op_linear_quantizer_config = (
             CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs)
         )
+        enumerated_shapes = CoreMLBackend.enumerated_shapes_from_compile_specs(
+            compile_specs
+        )
+
+        # If using enumerated shapes, we need to pass the inputs to CoreML's convert() function
+        # explicitly
+        ct_inputs = None
+        if enumerated_shapes is not None:
+            ct_inputs = _get_ct_inputs(edge_program, enumerated_shapes)
+
+            # Check that there are not multiple enumerated inputs if iOS is below 18
+            if (minimum_deployment_target is None) or (
+                minimum_deployment_target < ct.target.iOS18
+            ):
+                n_enumerated_inputs = 0
+                for ct_in in ct_inputs:
+                    if isinstance(ct_in.shape, ct.EnumeratedShapes):
+                        n_enumerated_inputs += 1
+                if n_enumerated_inputs > 1:
+                    raise ValueError(
+                        f"Your program has {n_enumerated_inputs} enumerated inputs, but the minimum_deployment_target is set to {minimum_deployment_target}. Multiple enumerated inputs require iOS18 or later."
+                    )
 
         # Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
         # get_compiled_model_path() requires a loaded model.
@@ -459,6 +534,7 @@ def preprocess(
             compute_precision=model_compute_precision,
             minimum_deployment_target=minimum_deployment_target,
             compute_units=compute_units,
+            inputs=ct_inputs,
         )
 
         if op_linear_quantizer_config is not None:
diff --git a/backends/apple/coreml/compiler/enumerated_shape_utils.py b/backends/apple/coreml/compiler/enumerated_shape_utils.py
new file mode 100644
index 00000000000..663830f702e
--- /dev/null
+++ b/backends/apple/coreml/compiler/enumerated_shape_utils.py
@@ -0,0 +1,244 @@
+import json
+from dataclasses import asdict, dataclass
+from typing import Optional, Tuple
+
+import coremltools as ct
+import torch
+from coremltools.converters.mil.frontend.torch.utils import TORCH_DTYPE_TO_MIL_DTYPE
+
+_IGNORE_RANGE_CONSTRAINTS: bool = False
+
+
+@dataclass(frozen=True, slots=True)
+class _SymInt:
+    key_name: str
+    low: Optional[int]
+    high: Optional[int]
+
+    @classmethod
+    def from_symint_and_range_constraints(cls, s: torch.SymInt, range_constraints=None):
+        # Canonicalize: "Sym(s0)" -> "s0", or leave "s0" as is
+        def _symkey(sym: torch.SymInt) -> str:
+            s = str(sym)
+            return s[4:-1] if s.startswith("Sym(") and s.endswith(")") else s
+
+        # Convert symint to int; infinity is converted to None
+        def _as_int_or_none(b):
+            if b is None:
+                return None
+            s = str(b)
+            if s in {"int_oo", "-int_oo", "oo", "-oo", "Infinity", "-Infinity"}:
+                return None
+            return int(s)
+
+        # Get low/high from range_constraints if provided
+        low, high = None, None
+        if range_constraints is not None:
+            for k, v in range_constraints.items():
+                if _symkey(k) == _symkey(s):
+                    low = _as_int_or_none(getattr(v, "lower", getattr(v, "min", None)))
+                    high = _as_int_or_none(getattr(v, "upper", getattr(v, "max", None)))
+        return _SymInt(_symkey(s), low, high)
+
+
+@dataclass(frozen=True, slots=True)
+class _SymbolicShape:
+    shape: Tuple[int | _SymInt, ...]
+
+    @classmethod
+    def from_shape_and_range_constraints(cls, shape, range_constraints=None):
+        out_shape = []
+        for s in shape:
+            if isinstance(s, int):
+                assert s >= 0
+                out_shape.append(s)
+            elif isinstance(s, torch.SymInt):
+                out_shape.append(
+                    _SymInt.from_symint_and_range_constraints(s, range_constraints)
+                )
+            else:
+                raise ValueError(f"Unexpected type found in shape: {type(s)}")
+        out_shape = tuple(out_shape)
+        return _SymbolicShape(out_shape)
+
+    def is_static_shape(self):
+        for s in self.shape:
+            if isinstance(s, _SymInt):
+                return False
+        return True
+
+    def __len__(self):
+        return len(self.shape)
+
+    def __getitem__(self, key):
+        return self.shape[key]
+
+    def to_dict(self):
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, d):
+        assert len(d) == 1 and "shape" in d
+        shape = []
+        for s in d["shape"]:
+            if isinstance(s, int):
+                shape.append(s)
+            elif isinstance(s, dict):
+                assert len(s) == 3 and "key_name" in s and "low" in s and "high" in s
+                shape.append(_SymInt(**s))
+            else:
+                raise ValueError(f"Unexpected type found in shape: {type(s)}")
+        shape = tuple(shape)
+        return _SymbolicShape(shape)
+
+
+def _iterate_over_fake_user_inputs(ep):
+    user_inputs = ep.graph_signature.user_inputs
+    for node in ep.graph.nodes:
+        if node.op == "placeholder" and node.name in user_inputs:
+            yield (node.name, node.meta["val"])
+
+
+def _create_enumeration_map(ep, enumerated_shapes, *, ignore_range_constraints=False):
+    # Each input should have the same number of enumerations
+    assert len(enumerated_shapes) > 0, "No enumerated shapes provided"
+    num_enumerations = None
+    for name, eshapes in enumerated_shapes.items():
+        if num_enumerations is None:
+            num_enumerations = len(eshapes)
+        else:
+            assert (
+                len(eshapes) > 1
+            ), f"Input {name} only has {len(eshapes)} enumerated shapes provided. You should not specify enumerated shapes for an input that has only 1 shape."
+            assert (
+                len(eshapes) == num_enumerations
+            ), f"Input {name} has {len(eshapes)} enumerated shapes provided, but other inputs have {num_enumerations} enumerated shapes"
+
+    symbolic_shape_to_enumerations = {}
+    for name, fake_input in _iterate_over_fake_user_inputs(ep):
+        shape = fake_input.shape
+        serialized_shape = _SymbolicShape.from_shape_and_range_constraints(
+            shape, ep.range_constraints if not ignore_range_constraints else None
+        )
+        if serialized_shape.is_static_shape():
+            continue
+        # Shape is dynamic
+        if name not in enumerated_shapes:
+            raise ValueError(
+                f"The input {name} has a symbolic shape, but you did not provide an enumeration for it"
+            )
+        # Validate
+        for eshape in enumerated_shapes[name]:
+            assert len(serialized_shape) == len(
+                eshape
+            ), f"In {name}, the rank of the enumeration is {len(eshape)}, but the symbolic shape has rank {len(serialized_shape)}"
+            for i in range(len(eshape)):
+                assert isinstance(
+                    eshape[i], int
+                ), f"Enumerated shapes must be ints, but got {type(eshape[i])}."
+                assert eshape[i] >= 1, "Each enumerated shape dimension must be >= 1"
+                if isinstance(serialized_shape[i], int):
+                    assert (
+                        serialized_shape[i] == eshape[i]
+                    ), f"In {name}, the shape enumeration {eshape} does not match {shape} on the non-symbolic value at index {i}"
+                else:
+                    # Check eshape is within bounds
+                    if serialized_shape[i].low is not None:
+                        # We add a special case for when the lower bound is 2, because torch.export does not usually allow 1 as a lower bound
+                        assert (eshape[i] >= serialized_shape[i].low) or (
+                            eshape[i] == 1 and serialized_shape[i].low == 2
+                        ), f"In {name}, the shape enumeration {eshape} violates the lower range-constraint on the symbolic shape {shape} at index {i}"
+                    if serialized_shape[i].high is not None:
+                        assert (
+                            eshape[i] <= serialized_shape[i].high
+                        ), f"In {name}, the shape enumeration {eshape} violates the upper range-constraint on the symbolic shape {shape} at index {i}"
+        if serialized_shape in symbolic_shape_to_enumerations:
+            enumerations, names = symbolic_shape_to_enumerations[serialized_shape]
+            assert (
+                enumerations == enumerated_shapes[name]
+            ), f"The symbolic shape {shape} has multiple enumerations defined. A new enumeration is defined for input {name}, but the existing inputs {names} have a different one defined. If these inputs have different enumerations, they should be exported with different symbolic shapes."
+            names.append(name)
+            symbolic_shape_to_enumerations[serialized_shape] = (enumerations, names)
+        else:
+            symbolic_shape_to_enumerations[serialized_shape] = (
+                enumerated_shapes[name],
+                [name],
+            )
+    return symbolic_shape_to_enumerations
+
+
+class _SymbolicShapeToEnumeratedShapeMap:
+    def __init__(self, emap):
+        self.emap = emap
+
+    def to_json(self):
+        json_list = []
+        for k in self.emap:
+            json_list.append((k.to_dict(), self.emap[k]))
+        return json.dumps(json_list)
+
+    @classmethod
+    def from_json(cls, s):
+        emap = {}
+        json_list = json.loads(s)
+        for k, v in json_list:
+            k = _SymbolicShape.from_dict(k)
+            emap[k] = tuple(v)
+        return cls(emap)
+
+    @classmethod
+    def from_exported_program(
+        cls,
+        ep,
+        enumerated_shapes,
+        *,
+        ignore_range_constraints=_IGNORE_RANGE_CONSTRAINTS,
+    ):
+        emap = _create_enumeration_map(
+            ep, enumerated_shapes, ignore_range_constraints=ignore_range_constraints
+        )
+        return cls(emap)
+
+    def __getitem__(self, key: _SymbolicShape):
+        return self.emap[key][0]
+
+    def __contains__(self, key):
+        return key in self.emap
+
+    def __repr__(self):
+        return f"_SymbolicShapeToEnumeratedShapeMap(emap={self.emap})"
+
+
+def _get_ct_inputs(ep, emap: _SymbolicShapeToEnumeratedShapeMap):
+    ct_inputs = []
+    for name, fake_input in _iterate_over_fake_user_inputs(ep):
+
+        # CoreML can do funny conversions in ct.convert (e.g., int64 -> int32, int16 -> int32), so here
+        # we restrict users to dtypes we know are supported
+        _ENUMERATED_SHAPE_INPUT_DTYPES = [torch.float16, torch.float32, torch.int32]
+        for dtype in _ENUMERATED_SHAPE_INPUT_DTYPES:
+            assert dtype in TORCH_DTYPE_TO_MIL_DTYPE
+        assert (
+            fake_input.dtype in _ENUMERATED_SHAPE_INPUT_DTYPES
+        ), f"When using enumerated shapes, all inputs must have one of the following dtypes {_ENUMERATED_SHAPE_INPUT_DTYPES}, but {name} has dtype {fake_input.dtype}"
+
+        ct_dtype = TORCH_DTYPE_TO_MIL_DTYPE[fake_input.dtype]
+        shape = fake_input.shape
+        serializable_shape = _SymbolicShape.from_shape_and_range_constraints(
+            shape, ep.range_constraints if not _IGNORE_RANGE_CONSTRAINTS else None
+        )
+        if serializable_shape.is_static_shape():
+            ct_inputs.append(
+                ct.TensorType(name=name, shape=serializable_shape.shape, dtype=ct_dtype)
+            )
+            continue
+        # Dynamic shape
+        assert (
+            serializable_shape in emap
+        ), f"The shape of input {name} ({serializable_shape}) is not in the _SymbolicShapeToEnumeratedShapeMap={emap}"
+        enumerations = emap[serializable_shape]
+        ct_enumerated_shape = ct.EnumeratedShapes(shapes=enumerations)
+        ct_inputs.append(
+            ct.TensorType(name=name, shape=ct_enumerated_shape, dtype=ct_dtype)
+        )
+    return ct_inputs
diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py
index 93506e6d985..6b3a73599af 100644
--- a/backends/apple/coreml/partition/coreml_partitioner.py
+++ b/backends/apple/coreml/partition/coreml_partitioner.py
@@ -10,6 +10,9 @@
 import torch
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
+from executorch.backends.apple.coreml.compiler.coreml_preprocess import (
+    COMPILE_SPEC_KEYS,
+)
 from executorch.backends.apple.coreml.logging import get_coreml_log_level
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -192,6 +195,13 @@ def __init__(
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
         self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
+
+        for compile_spec in compile_specs or []:
+            if compile_spec.key == COMPILE_SPEC_KEYS.ENUMERATED_SHAPES.value:
+                assert (
+                    lower_full_graph
+                ), "lower_full_graph must be True in the CoreMLPartitioner when using an enumerated shape compile spec"
+
         self.delegation_spec = DelegationSpec(
             backend_id=CoreMLBackend.__name__,
             compile_specs=compile_specs if compile_specs is not None else [],
diff --git a/backends/apple/coreml/test/test_enumerated_shapes.py b/backends/apple/coreml/test/test_enumerated_shapes.py
new file mode 100644
index 00000000000..d7d209c4ce4
--- /dev/null
+++ b/backends/apple/coreml/test/test_enumerated_shapes.py
@@ -0,0 +1,112 @@
+# Copyright © 2023 Apple Inc. All rights reserved.
+#
+# Please refer to the license found in the LICENSE file in the root directory of the source tree.
+
+import unittest
+
+import coremltools as ct
+
+import executorch.exir
+
+import torch
+
+from executorch.backends.apple.coreml.compiler import CoreMLBackend
+from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.backends.apple.coreml.test.test_coreml_utils import (
+    IS_VALID_TEST_RUNTIME,
+)
+
+if IS_VALID_TEST_RUNTIME:
+    from executorch.runtime import Runtime
+
+
+class TestEnumeratedShapes(unittest.TestCase):
+    def _compare_outputs(self, executorch_program, eager_program, example_inputs):
+        if not IS_VALID_TEST_RUNTIME:
+            return
+        runtime = Runtime.get()
+        program = runtime.load_program(executorch_program.buffer)
+        method = program.load_method("forward")
+        et_outputs = method.execute(example_inputs)[0]
+        eager_outputs = eager_program(*example_inputs)
+        self.assertTrue(
+            torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
+        )
+
+    def test_e2e(self):
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(10, 5)
+                self.linear2 = torch.nn.Linear(11, 5)
+
+            def forward(self, x, y):
+                return self.linear1(x).sum() + self.linear2(y)
+
+        model = Model()
+        example_inputs = (
+            torch.randn((4, 6, 10)),
+            torch.randn((5, 11)),
+        )
+        enumerated_shapes = {"x": [[1, 5, 10], [4, 6, 10]], "y": [[3, 11], [5, 11]]}
+        dynamic_shapes = [
+            {
+                0: torch.export.Dim.AUTO(min=1, max=4),
+                1: torch.export.Dim.AUTO(min=5, max=6),
+            },
+            {0: torch.export.Dim.AUTO(min=3, max=5)},
+        ]
+        ep = torch.export.export(
+            model.eval(), example_inputs, dynamic_shapes=dynamic_shapes
+        )
+
+        compile_specs = CoreMLBackend.generate_compile_specs(
+            minimum_deployment_target=ct.target.iOS18
+        )
+        compile_specs.append(
+            CoreMLBackend.generate_enumerated_shapes_compile_spec(
+                ep,
+                enumerated_shapes,
+            )
+        )
+        partitioner = CoreMLPartitioner(
+            compile_specs=compile_specs, lower_full_graph=True
+        )
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[partitioner],
+        )
+        et_prog = delegated_program.to_executorch()
+
+        good_input1 = (
+            torch.randn((1, 5, 10)),
+            torch.randn((3, 11)),
+        )
+        good_input2 = (
+            torch.randn((4, 6, 10)),
+            torch.randn((5, 11)),
+        )
+        bad_input = (
+            torch.randn((1, 5, 10)),
+            torch.randn((5, 11)),
+        )
+        bad_input2 = (
+            torch.randn((2, 7, 12)),
+            torch.randn((3, 11)),
+        )
+
+        self._compare_outputs(et_prog, model, good_input1)
+        self._compare_outputs(et_prog, model, good_input2)
+        if IS_VALID_TEST_RUNTIME:
+            self.assertRaises(
+                RuntimeError, lambda: self._compare_outputs(et_prog, model, bad_input)
+            )
+            self.assertRaises(
+                RuntimeError, lambda: self._compare_outputs(et_prog, model, bad_input2)
+            )
+
+
+if __name__ == "__main__":
+    test_runner = TestEnumeratedShapes()
+    test_runner.test_e2e()
diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py
index 0d6b581ee72..4fdbfdd8f21 100644
--- a/backends/apple/coreml/test/test_torch_ops.py
+++ b/backends/apple/coreml/test/test_torch_ops.py
@@ -158,10 +158,6 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)
 
-    @unittest.skipIf(
-        not hasattr(torch.version, "git_version"),
-        "Enable in fbcode once D79658061 lands",
-    )
     def test_dequantize_codebook_linear(self):
         model, example_inputs = self._get_test_model()
         quantize_(
@@ -189,10 +185,6 @@ def test_dequantize_codebook_linear(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)
 
-    @unittest.skipIf(
-        not hasattr(torch.version, "git_version"),
-        "Enable in fbcode once D79658061 lands",
-    )
     def test_dequantize_codebook_embedding(self):
         model, example_inputs = self._get_test_model()
         quantize_(