From 53b5d6fc494219dc1fb6711378b9fd05dcbd5027 Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Mon, 30 Jun 2025 11:57:20 -0700
Subject: [PATCH] Update backends-coreml.md (#12100)

Adds more context on quantization to address
https://github.com/pytorch/executorch/issues/12059.

(cherry picked from commit 488dfed4bacd2e7b13e8dbba7cf6e248c4bd6979)
---
 docs/source/backends-coreml.md | 44 +++++++++++++++++++++++++---------
 1 file changed, 33 insertions(+), 11 deletions(-)

diff --git a/docs/source/backends-coreml.md b/docs/source/backends-coreml.md
index 37c89b56a54..9820b7fe9b1 100644
--- a/docs/source/backends-coreml.md
+++ b/docs/source/backends-coreml.md
@@ -112,18 +112,17 @@ mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFA
 sample_inputs = (torch.randn(1, 3, 224, 224), )
 
 # Step 1: Define a LinearQuantizerConfig and create an instance of a CoreMLQuantizer
-quantization_config = ct.optimize.torch.quantization.LinearQuantizerConfig.from_dict(
-    {
-        "global_config": {
-            "quantization_scheme": ct.optimize.torch.quantization.QuantizationScheme.symmetric,
-            "milestones": [0, 0, 10, 10],
-            "activation_dtype": torch.quint8,
-            "weight_dtype": torch.qint8,
-            "weight_per_channel": True,
-        }
-    }
+# Note that "linear" here does not mean that only linear layers are quantized; it means that
+# linear (i.e., affine) quantization is being performed
+static_8bit_config = ct.optimize.torch.quantization.LinearQuantizerConfig(
+    global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
+        quantization_scheme="symmetric",
+        activation_dtype=torch.quint8,
+        weight_dtype=torch.qint8,
+        weight_per_channel=True,
+    )
 )
-quantizer = CoreMLQuantizer(quantization_config)
+quantizer = CoreMLQuantizer(static_8bit_config)
 
 # Step 2: Export the model for training
 training_gm = torch.export.export_for_training(mobilenet_v2, sample_inputs).module()
@@ -153,6 +152,24 @@ et_program = to_edge_transform_and_lower(
 ).to_executorch()
 ```
 
+The above does static quantization (both activations and weights are quantized). Quantizing activations requires calibrating the model on representative data (a calibration sketch follows the weight-only example below). You can also do weight-only quantization, which does not require calibration data, by setting `activation_dtype` to `torch.float32`:
+
+```
+weight_only_8bit_config = ct.optimize.torch.quantization.LinearQuantizerConfig(
+    global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
+        quantization_scheme="symmetric",
+        activation_dtype=torch.float32,
+        weight_dtype=torch.qint8,
+        weight_per_channel=True,
+    )
+)
+quantizer = CoreMLQuantizer(weight_only_8bit_config)
+prepared_model = prepare_pt2e(training_gm, quantizer)
+quantized_model = convert_pt2e(prepared_model)
+```
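+
+For the static configuration, calibration is the step between `prepare_pt2e` and `convert_pt2e`: simply run the prepared model on representative inputs so the inserted observers can record activation ranges. A minimal sketch, using `sample_inputs` as a stand-in for a real calibration set:
+
+```
+quantizer = CoreMLQuantizer(static_8bit_config)
+prepared_model = prepare_pt2e(training_gm, quantizer)
+
+# Calibrate: run the prepared model on representative data so the
+# observers can record activation ranges (stand-in: the example inputs)
+for calibration_batch in [sample_inputs]:
+    prepared_model(*calibration_batch)
+
+quantized_model = convert_pt2e(prepared_model)
+```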
+
+Note that static quantization requires exporting the model for iOS17 or later.
+
 See [PyTorch 2 Export Post Training Quantization](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html) for more information.
 
 ----
@@ -204,3 +221,8 @@ This happens because the model is in FP16, but CoreML interprets some of the arg
     raise RuntimeError("BlobWriter not loaded")
 
 If you're using Python 3.13, try reducing your python version to Python 3.12. coremltools does not support Python 3.13, see this [issue](https://github.com/apple/coremltools/issues/2487).
+
+### At runtime
+1. [ETCoreMLModelCompiler.mm:55] [Core ML] Failed to compile model, error = Error Domain=com.apple.mlassetio Code=1 "Failed to parse the model specification. Error: Unable to parse ML Program: at unknown location: Unknown opset 'CoreML7'." UserInfo={NSLocalizedDescription=Failed to par$
+