Skip to content

Commit a1fce54

Browse files
committed
Update CoreML docs
1 parent ec35f56 commit a1fce54

File tree

1 file changed

+68
-1
lines changed

1 file changed

+68
-1
lines changed

docs/source/backends-coreml.md

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ The CoreML partitioner API allows for configuration of the model delegation to C
6161
- `skip_ops_for_coreml_delegation`: Allows you to skip ops for delegation by CoreML. By default, all ops that CoreML supports will be delegated. See [here](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/test/test_coreml_partitioner.py#L42) for an example of skipping an op for delegation.
6262
- `compile_specs`: A list of `CompileSpec`s for the CoreML backend. These control low-level details of CoreML delegation, such as the compute unit (CPU, GPU, ANE), the iOS deployment target, and the compute precision (FP16, FP32). These are discussed more below.
6363
- `take_over_mutable_buffer`: A boolean that indicates whether PyTorch mutable buffers in stateful models should be converted to [CoreML `MLState`](https://developer.apple.com/documentation/coreml/mlstate). If set to `False`, mutable buffers in the PyTorch graph are converted to graph inputs and outputs to the CoreML lowered module under the hood. Generally, setting `take_over_mutable_buffer` to true will result in better performance, but using `MLState` requires iOS >= 18.0, macOS >= 15.0, and Xcode >= 16.0.
64+
- `take_over_constant_data`: A boolean that indicates whether PyTorch constant data like model weights should be consumed by the CoreML delegate. If set to False, constant data is passed to the CoreML delegate as inputs. By deafault, take_over_constant_data=True.
65+
- `lower_full_graph`: A boolean that indicates whether the entire graph must be lowered to CoreML. If set to True and CoreML does not support an op, an error is raised during lowering. If set to False and CoreML does not support an op, the op is executed on the CPU by ExecuTorch. Although setting `lower_full_graph`=False can allow a model to lower where it would otherwise fail, it can introduce performance overhead in the model when there are unsupported ops. You will see warnings about unsupported ops during lowering if there are any. By default, `lower_full_graph`=False.
66+
6467

6568
#### CoreML CompileSpec
6669

@@ -70,10 +73,14 @@ A list of `CompileSpec`s is constructed with [`CoreMLBackend.generate_compile_sp
7073
- `coremltools.ComputeUnit.CPU_ONLY` (uses the CPU only)
7174
- `coremltools.ComputeUnit.CPU_AND_GPU` (uses both the CPU and GPU, but not the ANE)
7275
- `coremltools.ComputeUnit.CPU_AND_NE` (uses both the CPU and ANE, but not the GPU)
73-
- `minimum_deployment_target`: The minimum iOS deployment target (e.g., `coremltools.target.iOS18`). The default value is `coremltools.target.iOS15`.
76+
- `minimum_deployment_target`: The minimum iOS deployment target (e.g., `coremltools.target.iOS18`). By default, the smallest deployment target needed to deploy the model is selected. During export, you will see a warning about the "CoreML specification version" that was used for the model, which maps onto a deployment target as discussed [here](https://apple.github.io/coremltools/mlmodel/Format/Model.html#model). If you need to control the deployment target, please specify it explicitly.
7477
- `compute_precision`: The compute precision used by CoreML (`coremltools.precision.FLOAT16` or `coremltools.precision.FLOAT32`). The default value is `coremltools.precision.FLOAT16`. Note that the compute precision is applied no matter what dtype is specified in the exported PyTorch model. For example, an FP32 PyTorch model will be converted to FP16 when delegating to the CoreML backend by default. Also note that the ANE only supports FP16 precision.
7578
- `model_type`: Whether the model should be compiled to the CoreML [mlmodelc format](https://developer.apple.com/documentation/coreml/downloading-and-compiling-a-model-on-the-user-s-device) during .pte creation ([`CoreMLBackend.MODEL_TYPE.COMPILED_MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L71)), or whether it should be compiled to mlmodelc on device ([`CoreMLBackend.MODEL_TYPE.MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L70)). Using `CoreMLBackend.MODEL_TYPE.COMPILED_MODEL` and doing compilation ahead of time should improve the first time on-device model load time.
7679

80+
#### Backward compatibility
81+
82+
CoreML supports backward compatibility via the `minimum_deployment_target` option. A model exported with a specific deployment target is guaranteed to work on all deployment targets >= the specified deployment target. For example, a model exported with `coremltools.target.iOS17` will work on iOS 17 or higher.
83+
7784
### Testing the Model
7885

7986
After generating the CoreML-delegated .pte, the model can be tested from Python using the ExecuTorch runtime Python bindings. This can be used to quickly check the model and evaluate numerical accuracy. See [Testing the Model](using-executorch-export.md#testing-the-model) for more information.
@@ -173,6 +180,66 @@ Quantizing activations requires calibrating the model on representative data. A
173180

174181
See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) for more information.
175182

183+
### LLM quantization with quantize_
184+
185+
The CoreML backend also supports quantizing models with the [torchao](https://github.com/pytorch/ao) quantize_ API. This is most commonly used for LLMs, requiring more advanced quantization. Since quantize_ is not backend aware, it is important to use a config that is compatible with CoreML:
186+
187+
* Quantize embedding/linear layers with IntxWeightOnlyConfig (with weight_dtype torch.int4 or torch.int8, using PerGroup or PerAxis granularity)
188+
* Quantize embedding/linear layers with CodebookWeightOnlyConfig (with dtype torch.uint1 through torch.uint8, using various block sizes)
189+
190+
Below is an example that quantizes embeddings to 8-bits per-axis and linear layers to 4-bits using group_size=32 with affine quantization:
191+
192+
```python
193+
from torchao.quantization.granularity import PerGroup, PerAxis
194+
from torchao.quantization.quant_api import (
195+
IntxWeightOnlyConfig,
196+
quantize_,
197+
)
198+
199+
# Quantize embeddings with 8-bits, per channel
200+
embedding_config = IntxWeightOnlyConfig(
201+
weight_dtype=torch.int8,
202+
granularity=PerAxis(0),
203+
)
204+
qunatize_(
205+
eager_model,
206+
lambda m, fqn: isinstance(m, torch.nn.Embedding),
207+
)
208+
209+
# Quantize linear layers with 4-bits, per-group
210+
linear_config = IntxWeightOnlyConfig(
211+
weight_dtype=torch.int4,
212+
granularity=PerGroup(32),
213+
)
214+
quantize_(
215+
eager_model,
216+
linear_config,
217+
)
218+
```
219+
220+
Below is another example that uses codebook quantization to quantize both embeddings and linear layers to 3-bits.
221+
In the coremltools documentation, this is called [palettization](https://apple.github.io/coremltools/docs-guides/source/opt-palettization-overview.html):
222+
223+
```
224+
from torchao.quantization.quant_api import (
225+
quantize_,
226+
)
227+
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
228+
229+
quant_config = CodebookWeightOnlyConfig(
230+
dtype=torch.uint3,
231+
# There is one LUT per 16 columns
232+
block_size=[-1, 16],
233+
)
234+
235+
quantize_(
236+
eager_model,
237+
quant_config,
238+
lambda m, fqn: isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear),
239+
)
240+
```
241+
242+
Both of the above examples will export and lower to CoreML with the to_edge_transform_and_lower API.
176243

177244
----
178245

0 commit comments

Comments
 (0)