# Update CoreML docs #13120
@@ -61,6 +61,9 @@ The CoreML partitioner API allows for configuration of the model delegation to CoreML
- `skip_ops_for_coreml_delegation`: Allows you to skip ops for delegation by CoreML. By default, all ops that CoreML supports will be delegated. See [here](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/test/test_coreml_partitioner.py#L42) for an example of skipping an op for delegation.
- `compile_specs`: A list of `CompileSpec`s for the CoreML backend. These control low-level details of CoreML delegation, such as the compute unit (CPU, GPU, ANE), the iOS deployment target, and the compute precision (FP16, FP32). These are discussed more below.
- `take_over_mutable_buffer`: A boolean that indicates whether PyTorch mutable buffers in stateful models should be converted to [CoreML `MLState`](https://developer.apple.com/documentation/coreml/mlstate). If set to `False`, mutable buffers in the PyTorch graph are converted to graph inputs and outputs to the CoreML lowered module under the hood. Generally, setting `take_over_mutable_buffer` to `True` will result in better performance, but using `MLState` requires iOS >= 18.0, macOS >= 15.0, and Xcode >= 16.0.
- `take_over_constant_data`: A boolean that indicates whether PyTorch constant data, like model weights, should be consumed by the CoreML delegate. If set to `False`, constant data is passed to the CoreML delegate as inputs. By default, `take_over_constant_data=True`.
- `lower_full_graph`: A boolean that indicates whether the entire graph must be lowered to CoreML. If set to `True` and CoreML does not support an op, an error is raised during lowering. If set to `False` and CoreML does not support an op, the op is executed on the CPU by ExecuTorch. Although setting `lower_full_graph=False` can allow a model to lower where it would otherwise fail, it can introduce performance overhead when the model contains unsupported ops. You will see warnings about unsupported ops during lowering if there are any. By default, `lower_full_graph=False`. A sketch of configuring these options is shown below.
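
The snippet below is a minimal sketch of configuring some of these options; the toy model, example inputs, and option values are illustrative, and the import paths assume a recent ExecuTorch build.

```python
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower

# Toy model for illustration only
model = torch.nn.Linear(16, 16).eval()
example_inputs = (torch.randn(1, 16),)

partitioner = CoreMLPartitioner(
    # Run unsupported ops on the ExecuTorch CPU backend instead of erroring
    lower_full_graph=False,
    # Convert mutable buffers to CoreML MLState (requires iOS >= 18)
    take_over_mutable_buffer=True,
)

ep = torch.export.export(model, example_inputs)
executorch_program = to_edge_transform_and_lower(
    ep, partitioner=[partitioner]
).to_executorch()
```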
#### CoreML CompileSpec
@@ -70,10 +73,14 @@ A list of `CompileSpec`s is constructed with `CoreMLBackend.generate_compile_specs`
- `coremltools.ComputeUnit.CPU_ONLY` (uses the CPU only)
- `coremltools.ComputeUnit.CPU_AND_GPU` (uses both the CPU and GPU, but not the ANE)
- `coremltools.ComputeUnit.CPU_AND_NE` (uses both the CPU and ANE, but not the GPU)
- `minimum_deployment_target`: The minimum iOS deployment target (e.g., `coremltools.target.iOS18`). By default, the smallest deployment target needed to deploy the model is selected. During export, you will see a warning about the "CoreML specification version" that was used for the model, which maps onto a deployment target as discussed [here](https://apple.github.io/coremltools/mlmodel/Format/Model.html#model). If you need to control the deployment target, please specify it explicitly.
- `compute_precision`: The compute precision used by CoreML (`coremltools.precision.FLOAT16` or `coremltools.precision.FLOAT32`). The default value is `coremltools.precision.FLOAT16`. Note that the compute precision is applied no matter what dtype is specified in the exported PyTorch model. For example, an FP32 PyTorch model will be converted to FP16 when delegating to the CoreML backend by default. Also note that the ANE only supports FP16 precision.
- `model_type`: Whether the model should be compiled to the CoreML [mlmodelc format](https://developer.apple.com/documentation/coreml/downloading-and-compiling-a-model-on-the-user-s-device) during .pte creation ([`CoreMLBackend.MODEL_TYPE.COMPILED_MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L71)), or whether it should be compiled to mlmodelc on device ([`CoreMLBackend.MODEL_TYPE.MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L70)). Using `CoreMLBackend.MODEL_TYPE.COMPILED_MODEL` and doing compilation ahead of time should improve the first-time on-device model load time. A sketch of generating compile specs follows this list.
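
As a minimal sketch, compile specs might be generated as follows and then passed to `CoreMLPartitioner` via its `compile_specs` argument; the option values here are illustrative, not recommendations.

```python
import coremltools as ct

from executorch.backends.apple.coreml.compiler import CoreMLBackend

compile_specs = CoreMLBackend.generate_compile_specs(
    compute_unit=ct.ComputeUnit.ALL,
    minimum_deployment_target=ct.target.iOS17,
    compute_precision=ct.precision.FLOAT16,
    # Compile to mlmodelc at .pte creation time to speed up the first load
    model_type=CoreMLBackend.MODEL_TYPE.COMPILED_MODEL,
)
```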
#### Backward compatibility
CoreML supports backward compatibility via the `minimum_deployment_target` option. A model exported with a specific deployment target is guaranteed to work on all deployment targets >= the specified deployment target. For example, a model exported with `coremltools.target.iOS17` will work on iOS 17 or higher.
### Testing the Model
After generating the CoreML-delegated .pte, the model can be tested from Python using the ExecuTorch runtime Python bindings. This can be used to quickly check the model and evaluate numerical accuracy. See [Testing the Model](using-executorch-export.md#testing-the-model) for more information.
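
A quick smoke test with the Python bindings might look like the sketch below; the `.pte` path and input shapes are placeholders, and the bindings must be built with the CoreML backend enabled (as in the macOS pip packages).

```python
import torch

from executorch.runtime import Runtime

runtime = Runtime.get()
program = runtime.load_program("model_coreml.pte")  # placeholder path
method = program.load_method("forward")
outputs = method.execute([torch.randn(1, 16)])  # match your model's inputs
print(outputs)
```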
@@ -173,6 +180,66 @@ Quantizing activations requires calibrating the model on representative data.

See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) for more information.
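
As a rough sketch of that PT2E flow with the CoreML quantizer (the config values, model, and calibration data here are placeholders; see the coremltools docs for the full set of options):

```python
import coremltools as ct
import torch

from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Illustrative static int8 quantization config
config = ct.optimize.torch.quantization.LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": ct.optimize.torch.quantization.QuantizationScheme.symmetric,
            "activation_dtype": torch.quint8,
            "weight_dtype": torch.qint8,
            "weight_per_channel": True,
        }
    }
)
quantizer = CoreMLQuantizer(config)

# model, example_inputs, and calibration_data are placeholders
training_gm = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(training_gm, quantizer)
for sample in calibration_data:  # calibrate on representative inputs
    prepared(*sample)
quantized = convert_pt2e(prepared)
```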
### LLM quantization with quantize_
> **Contributor:** @metascroy Is there a …
>
> **Author:** CoreML should select the required `minimum_deployment_target` automatically. For PT2E, it should select iOS17. But for `quantize_`, I noticed it was only working for iOS18 now (need to investigate further): #13122. In terms of how we enforce it: it should work automatically for PT2E, but let me know if it doesn't. For `quantize_`, I'll try to make it work automatically, but as an intermediate stop-gap, we can explicitly set iOS18 if `quantize_` is used in the recipe.

The CoreML backend also supports quantizing models with the [torchao](https://github.com/pytorch/ao) `quantize_` API. This is most commonly used for LLMs, which require more advanced quantization. Since `quantize_` is not backend-aware, it is important to use a config that is compatible with CoreML:
* Quantize embedding/linear layers with `IntxWeightOnlyConfig` (with `weight_dtype` `torch.int4` or `torch.int8`, using `PerGroup` or `PerAxis` granularity)
* Quantize embedding/linear layers with `CodebookWeightOnlyConfig` (with dtype `torch.uint1` through `torch.uint8`, using various block sizes)

Below is an example that quantizes embeddings to 8 bits per-axis and linear layers to 4 bits using `group_size=32` with affine quantization:
```python
import torch

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    quantize_,
)

# Quantize embeddings to 8 bits, per channel
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
)
quantize_(
    eager_model,
    embedding_config,
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# Quantize linear layers to 4 bits, per group
linear_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
)
quantize_(
    eager_model,
    linear_config,
)
```
Below is another example that uses codebook quantization to quantize both embeddings and linear layers to 3 bits. In the coremltools documentation, this is called [palettization](https://apple.github.io/coremltools/docs-guides/source/opt-palettization-overview.html):
```python
import torch

from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
from torchao.quantization.quant_api import quantize_

quant_config = CodebookWeightOnlyConfig(
    dtype=torch.uint3,
    # There is one LUT per 16 columns
    block_size=[-1, 16],
)

quantize_(
    eager_model,
    quant_config,
    lambda m, fqn: isinstance(m, (torch.nn.Embedding, torch.nn.Linear)),
)
```

Both of the above examples will export and lower to CoreML with the `to_edge_transform_and_lower` API, as sketched below.
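
A minimal sketch of that last step, assuming `eager_model` from the snippets above and a placeholder `example_inputs`:

```python
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower

# example_inputs is a placeholder matching the model's forward signature
ep = torch.export.export(eager_model, example_inputs)
executorch_program = to_edge_transform_and_lower(
    ep, partitioner=[CoreMLPartitioner()]
).to_executorch()

with open("model_coreml.pte", "wb") as f:
    f.write(executorch_program.buffer)
```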
----
> **Reviewer:** Noob question: it seems we publish the default as FLOAT16 in the `generate_compile_specs` function. What happens when a quantizer is used? Would the backend ignore this, or is it up to the user to make sure there is no `compute_precision` in compile specs?
>
> **Author:** Even for a quantized model, there is a compute precision. Compute precision controls the precision of the non-quantized ops in the model.