|
1 | | -# Core ML Backend |
| 1 | +# CoreML Backend |
2 | 2 |
|
3 | | -Core ML delegate is the ExecuTorch solution to take advantage of Apple's [CoreML framework](https://developer.apple.com/documentation/coreml) for on-device ML. With CoreML, a model can run on CPU, GPU, and the Apple Neural Engine (ANE). |
| 3 | +The CoreML delegate is the ExecuTorch solution for taking advantage of Apple's [CoreML framework](https://developer.apple.com/documentation/coreml) for on-device ML. With CoreML, a model can run on the CPU, GPU, and Apple Neural Engine (ANE).
4 | 4 |
|
5 | 5 | ## Features |
6 | 6 |
|
@@ -77,7 +77,98 @@ A list of `CompileSpec`s is constructed with [`CoreMLBackend.generate_compile_sp |
77 | 77 | - `compute_precision`: The compute precision used by CoreML (`coremltools.precision.FLOAT16` or `coremltools.precision.FLOAT32`). The default value is `coremltools.precision.FLOAT16`. Note that the compute precision is applied no matter what dtype is specified in the exported PyTorch model. For example, an FP32 PyTorch model will be converted to FP16 when delegating to the CoreML backend by default. Also note that the ANE only supports FP16 precision. |
78 | 78 | - `model_type`: Whether the model should be compiled to the CoreML [mlmodelc format](https://developer.apple.com/documentation/coreml/downloading-and-compiling-a-model-on-the-user-s-device) during .pte creation ([`CoreMLBackend.MODEL_TYPE.COMPILED_MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L71)), or whether it should be compiled to mlmodelc on device ([`CoreMLBackend.MODEL_TYPE.MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L70)). Using `CoreMLBackend.MODEL_TYPE.COMPILED_MODEL` and doing compilation ahead of time should improve the first on-device model load time, as sketched below.
79 | 79 |
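| | +For illustration, here is a minimal sketch that combines these options (it assumes the ExecuTorch CoreML imports shown; see the sections below for full export examples):
| | +
| | +```python
| | +import coremltools as ct
| | +
| | +from executorch.backends.apple.coreml.compiler import CoreMLBackend
| | +
| | +# Keep FP32 compute precision (the default is FP16) and compile to the
| | +# mlmodelc format at .pte creation time to speed up the first on-device load.
| | +compile_specs = CoreMLBackend.generate_compile_specs(
| | +    compute_precision=ct.precision.FLOAT32,
| | +    model_type=CoreMLBackend.MODEL_TYPE.COMPILED_MODEL,
| | +)
| | +```
| | +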
|
80 | | -#### Backward compatibility |
| 80 | +### Dynamic and Enumerated Shapes in CoreML Export |
| 81 | + |
| 82 | +When exporting an `ExportedProgram` to CoreML, **dynamic shapes** are mapped to [`RangeDim`](https://apple.github.io/coremltools/docs-guides/source/flexible-inputs.html#set-the-range-for-each-dimension). |
| 83 | +This enables CoreML `.pte` files to accept inputs with varying dimensions at runtime. |
| 84 | + |
| 85 | +⚠️ **Note:** The Apple Neural Engine (ANE) does not support true dynamic shapes. If a model relies on `RangeDim`, CoreML will fall back to scheduling the model on the CPU or GPU instead of the ANE. |
| 86 | + |
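| | +For reference, here is a minimal sketch of exporting a model with a dynamic batch dimension, which the CoreML backend lowers to a `RangeDim` (the model and dimension bounds below are hypothetical):
| | +
| | +```python
| | +import torch
| | +
| | +from executorch.backends.apple.coreml.partition import CoreMLPartitioner
| | +from executorch.exir import to_edge_transform_and_lower
| | +
| | +model = torch.nn.Linear(10, 5).eval()
| | +example_inputs = (torch.randn(2, 10),)
| | +
| | +# Mark the batch dimension as dynamic; its [1, 32] bounds become a RangeDim
| | +# in the lowered CoreML model.
| | +batch = torch.export.Dim("batch", min=1, max=32)
| | +ep = torch.export.export(model, example_inputs, dynamic_shapes=({0: batch},))
| | +
| | +delegated_program = to_edge_transform_and_lower(ep, partitioner=[CoreMLPartitioner()])
| | +et_prog = delegated_program.to_executorch()
| | +```
| | +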
| 89 | +#### Enumerated Shapes |
| 90 | + |
| 91 | +To enable limited flexibility on the ANE, and often better performance overall, you can export models using **[enumerated shapes](https://apple.github.io/coremltools/docs-guides/source/flexible-inputs.html#select-from-predetermined-shapes)**.
| 92 | + |
| 93 | +Enumerated shapes are not fully dynamic: they define a finite set of valid input shapes that CoreML can select from at runtime. This allows some adaptability while preserving ANE compatibility.
| 96 | + |
| 99 | +#### Specifying Enumerated Shapes |
| 100 | + |
| 101 | +Unlike `RangeDim`, **enumerated shapes are not part of the `ExportedProgram` itself.** |
| 102 | +They must be provided through a compile spec. |
| 103 | + |
| 104 | +For reference on how to do this, see: |
| 105 | +- The annotated code snippet below, and |
| 106 | +- The [end-to-end test in ExecuTorch](https://github.com/pytorch/executorch/blob/main/backends/apple/coreml/test/test_enumerated_shapes.py), which demonstrates how to specify enumerated shapes during export. |
| 107 | + |
| 108 | + |
| 109 | +```python |
| | +import coremltools as ct
| | +import torch
| | +
| | +import executorch.exir
| | +from executorch.backends.apple.coreml.compiler import CoreMLBackend
| | +from executorch.backends.apple.coreml.partition import CoreMLPartitioner
| | +
| 110 | +class Model(torch.nn.Module): |
| 111 | + def __init__(self): |
| 112 | + super().__init__() |
| 113 | + self.linear1 = torch.nn.Linear(10, 5) |
| 114 | + self.linear2 = torch.nn.Linear(11, 5) |
| 115 | + |
| 116 | + def forward(self, x, y): |
| 117 | + return self.linear1(x).sum() + self.linear2(y) |
| 118 | + |
| 119 | +model = Model() |
| 120 | +example_inputs = ( |
| 121 | + torch.randn((4, 6, 10)), |
| 122 | + torch.randn((5, 11)), |
| 123 | +) |
| 124 | + |
| 125 | +# Specify the enumerated shapes. Below we specify that: |
| 126 | +# |
| 127 | +# * x can take shape [1, 5, 10] and y can take shape [3, 11], or |
| 128 | +# * x can take shape [4, 6, 10] and y can take shape [5, 11] |
| 129 | +# |
| 130 | +# Any other input shapes will result in a runtime error. |
| 131 | +# |
| 132 | +# Note that we must export x and y with dynamic shapes in the ExportedProgram
| 133 | +# because some of their dimensions vary across the enumerated shapes.
| 134 | +enumerated_shapes = {"x": [[1, 5, 10], [4, 6, 10]], "y": [[3, 11], [5, 11]]} |
| 135 | +dynamic_shapes = [ |
| 136 | + { |
| 137 | + 0: torch.export.Dim.AUTO(min=1, max=4), |
| 138 | + 1: torch.export.Dim.AUTO(min=5, max=6), |
| 139 | + }, |
| 140 | + {0: torch.export.Dim.AUTO(min=3, max=5)}, |
| 141 | +] |
| 142 | +ep = torch.export.export( |
| 143 | + model.eval(), example_inputs, dynamic_shapes=dynamic_shapes |
| 144 | +) |
| 145 | + |
| 146 | +# If enumerated shapes are specified for multiple inputs, we must export |
| 147 | +# for iOS18+ |
| 148 | +compile_specs = CoreMLBackend.generate_compile_specs( |
| 149 | + minimum_deployment_target=ct.target.iOS18 |
| 150 | +) |
| 151 | +compile_specs.append( |
| 152 | + CoreMLBackend.generate_enumerated_shapes_compile_spec( |
| 153 | + ep, |
| 154 | + enumerated_shapes, |
| 155 | + ) |
| 156 | +) |
| 157 | + |
| 158 | +# When using an enumerated shape compile spec, you must specify lower_full_graph=True
| 159 | +# in the CoreMLPartitioner. Enumerated shapes are not supported for partially
| 160 | +# delegated models.
| 161 | +partitioner = CoreMLPartitioner( |
| 162 | + compile_specs=compile_specs, lower_full_graph=True |
| 163 | +) |
| 164 | +delegated_program = executorch.exir.to_edge_transform_and_lower( |
| 165 | + ep, |
| 166 | + partitioner=[partitioner], |
| 167 | +) |
| 168 | +et_prog = delegated_program.to_executorch() |
| 169 | +``` |
| 170 | + |
| 171 | +### Backward compatibility |
81 | 172 |
|
82 | 173 | CoreML supports backward compatibility via the `minimum_deployment_target` option. A model exported with a specific deployment target is guaranteed to work on all deployment targets >= the specified deployment target. For example, a model exported with `coremltools.target.iOS17` will work on iOS 17 or higher. |
83 | 174 |
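| | +For example, a minimal sketch of pinning the deployment target with the compile spec API described above:
| | +
| | +```python
| | +import coremltools as ct
| | +
| | +from executorch.backends.apple.coreml.compiler import CoreMLBackend
| | +
| | +# A .pte exported with these specs runs on iOS 17 and any newer target.
| | +compile_specs = CoreMLBackend.generate_compile_specs(
| | +    minimum_deployment_target=ct.target.iOS17,
| | +)
| | +```
| | +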
|
@@ -184,8 +275,8 @@ See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/ma |
184 | 275 |
|
185 | 276 | The CoreML backend also supports quantizing models with the [torchao](https://github.com/pytorch/ao) quantize_ API. This is most commonly used for LLMs, which often require more advanced quantization. Since quantize_ is not backend-aware, it is important to use a config that is compatible with CoreML:
186 | 277 |
|
187 | | -* Quantize embedding/linear layers with IntxWeightOnlyConfig (with weight_dtype torch.int4 or torch.int8, using PerGroup or PerAxis granularity) |
188 | | -* Quantize embedding/linear layers with CodebookWeightOnlyConfig (with dtype torch.uint1 through torch.uint8, using various block sizes) |
| 278 | +* Quantize embedding/linear layers with IntxWeightOnlyConfig (with weight_dtype torch.int4 or torch.int8, using PerGroup or PerAxis granularity). Using 4-bit or PerGroup quantization requires exporting with minimum_deployment_target >= ct.target.iOS18; 8-bit quantization with PerAxis granularity is supported on ct.target.iOS16+. See [CoreML `CompileSpec`](#coreml-compilespec) for more information on setting the deployment target.
| 279 | +* Quantize embedding/linear layers with CodebookWeightOnlyConfig (with dtype torch.uint1 through torch.uint8, using various block sizes). Quantizing with CodebookWeightOnlyConfig requires exporting with minimum_deployment_target >= ct.target.iOS18; see [CoreML `CompileSpec`](#coreml-compilespec) for more information on setting the deployment target.
189 | 280 |
|
190 | 281 | Below is an example that quantizes embeddings to 8-bits per-axis and linear layers to 4-bits using group_size=32 with affine quantization: |
191 | 282 |
|
@@ -229,8 +320,8 @@ from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyCon |
229 | 320 |
|
230 | 321 | quant_config = CodebookWeightOnlyConfig( |
231 | 322 | dtype=torch.uint3, |
232 | | - # There is one LUT per 16 columns |
233 | | - block_size=[-1, 16], |
| 323 | + # There is one LUT per 16 rows |
| 324 | + block_size=[16, -1], |
234 | 325 | ) |
235 | 326 |
|
236 | 327 | quantize_( |
@@ -293,6 +384,6 @@ This happens because the model is in FP16, but CoreML interprets some of the arg |
293 | 384 | If you're using Python 3.13, try downgrading to Python 3.12; coremltools does not support Python 3.13 per [coremltools issue #2487](https://github.com/apple/coremltools/issues/2487).
294 | 385 |
|
295 | 386 | ### At runtime |
296 | | -1. [ETCoreMLModelCompiler.mm:55] [Core ML] Failed to compile model, error = Error Domain=com.apple.mlassetio Code=1 "Failed to parse the model specification. Error: Unable to parse ML Program: at unknown location: Unknown opset 'CoreML7'." UserInfo={NSLocalizedDescription=Failed to par$ |
| 387 | +1. [ETCoreMLModelCompiler.mm:55] [CoreML] Failed to compile model, error = Error Domain=com.apple.mlassetio Code=1 "Failed to parse the model specification. Error: Unable to parse ML Program: at unknown location: Unknown opset 'CoreML7'." UserInfo={NSLocalizedDescription=Failed to par$ |
297 | 388 |
|
298 | 389 | This means the model requires the CoreML opset 'CoreML7', which requires running the model on iOS >= 17 or macOS >= 14.