@@ -6,6 +6,7 @@
 from typing import Any, Optional, Sequence
 
 import coremltools as ct
+import torch
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition.coreml_partitioner import (
@@ -18,11 +19,15 @@
 
 from executorch.exir import EdgeCompileConfig
 from executorch.export import (
+    AOQuantizationConfig,
     BackendRecipeProvider,
     ExportRecipe,
     LoweringRecipe,
+    QuantizationRecipe,
     RecipeType,
 )
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig
 
 
 class CoreMLRecipeProvider(BackendRecipeProvider):
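
Not part of the patch itself: the two torchao imports added above are the whole quantization-config surface the new weight-only recipes rely on. A minimal sketch of what they express, using the same constructors and keyword arguments this diff calls in `_build_torchao_quantized_recipe` below (assumes torch and torchao are installed):

    # Illustrative only -- not part of this PR.
    import torch
    from torchao.quantization.granularity import PerAxis, PerGroup
    from torchao.quantization.quant_api import IntxWeightOnlyConfig

    # Per-channel weight quantization: one scale per output channel (axis 0).
    per_channel_cfg = IntxWeightOnlyConfig(
        weight_dtype=torch.int8, granularity=PerAxis(axis=0)
    )

    # Per-group: one scale per contiguous group of 32 weights within a channel.
    per_group_cfg = IntxWeightOnlyConfig(
        weight_dtype=torch.int4, granularity=PerGroup(group_size=32)
    )
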
@@ -50,66 +55,315 @@ def create_recipe(
         # Validate kwargs
         self._validate_recipe_kwargs(recipe_type, **kwargs)
 
-        # Parse recipe type to get precision and compute unit
-        precision = None
         if recipe_type == CoreMLRecipeType.FP32:
-            precision = ct.precision.FLOAT32
+            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs)
         elif recipe_type == CoreMLRecipeType.FP16:
-            precision = ct.precision.FLOAT16
-
-        if precision is None:
-            raise ValueError(f"Unknown precision for recipe: {recipe_type.value}")
+            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs)
+        elif recipe_type == CoreMLRecipeType.PT2E_INT8_STATIC:
+            return self._build_pt2e_quantized_recipe(
+                recipe_type, activation_dtype=torch.quint8, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY:
+            return self._build_pt2e_quantized_recipe(
+                recipe_type, activation_dtype=torch.float32, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL:
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int4,
+                is_per_channel=True,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP:
+            group_size = kwargs.pop("group_size", 32)
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int4,
+                is_per_channel=False,
+                group_size=group_size,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL:
+            return self._build_torchao_quantized_recipe(
+                recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP:
+            group_size = kwargs.pop("group_size", 32)
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int8,
+                is_per_channel=False,
+                group_size=group_size,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            bits = kwargs.pop("bits", 3)
+            block_size = kwargs.pop("block_size", [-1, 16])
+            return self._build_codebook_quantized_recipe(
+                recipe_type, bits=bits, block_size=block_size, **kwargs
+            )
 
-        return self._build_recipe(recipe_type, precision, **kwargs)
+        return None
 
     def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
-        if not kwargs:
-            return
-        expected_keys = {"minimum_deployment_target", "compute_unit"}
+        """Validate kwargs for each recipe type"""
+        expected_keys = self._get_expected_keys(recipe_type)
+
         unexpected = set(kwargs.keys()) - expected_keys
         if unexpected:
             raise ValueError(
-                f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. "
-                f"Unexpected parameters: {list(unexpected)}"
+                f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
             )
+
+        self._validate_base_parameters(kwargs)
+        self._validate_group_size_parameter(recipe_type, kwargs)
+        self._validate_codebook_parameters(recipe_type, kwargs)
+
+    def _get_expected_keys(self, recipe_type: RecipeType) -> set:
+        """Get expected parameter keys for a recipe type"""
+        common_keys = {"minimum_deployment_target", "compute_unit"}
+
+        if recipe_type in [
+            CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
+            CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
+        ]:
+            return common_keys | {"group_size", "filter_fn"}
+        elif recipe_type in [
+            CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL,
+        ]:
+            return common_keys | {"filter_fn"}
+        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            return common_keys | {"bits", "block_size", "filter_fn"}
+        else:
+            return common_keys
+
+    def _validate_base_parameters(self, kwargs: Any) -> None:
+        """Validate minimum_deployment_target and compute_unit parameters"""
         if "minimum_deployment_target" in kwargs:
             minimum_deployment_target = kwargs["minimum_deployment_target"]
             if not isinstance(minimum_deployment_target, ct.target):
                 raise ValueError(
                     f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}"
                 )
+
         if "compute_unit" in kwargs:
             compute_unit = kwargs["compute_unit"]
             if not isinstance(compute_unit, ct.ComputeUnit):
                 raise ValueError(
                     f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}"
                 )
 
-    def _build_recipe(
+    def _validate_group_size_parameter(
+        self, recipe_type: RecipeType, kwargs: Any
+    ) -> None:
+        """Validate group_size parameter for applicable recipe types"""
+        if (
+            recipe_type
+            in [
+                CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
+                CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
+            ]
+            and "group_size" in kwargs
+        ):
+            group_size = kwargs["group_size"]
+            if not isinstance(group_size, int):
+                raise ValueError(
+                    f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
+                )
+            if group_size <= 0:
+                raise ValueError(
+                    f"Parameter 'group_size' must be positive, got: {group_size}"
+                )
+
+    def _validate_codebook_parameters(
+        self, recipe_type: RecipeType, kwargs: Any
+    ) -> None:
+        """Validate bits and block_size parameters for codebook recipe type"""
+        if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            return
+
+        if "bits" in kwargs:
+            bits = kwargs["bits"]
+            if not isinstance(bits, int):
+                raise ValueError(
+                    f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}"
+                )
+            if not (1 <= bits <= 8):
+                raise ValueError(
+                    f"Parameter 'bits' must be between 1 and 8, got: {bits}"
+                )
+
+        if "block_size" in kwargs:
+            block_size = kwargs["block_size"]
+            if not isinstance(block_size, list):
+                raise ValueError(
+                    f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}"
+                )
+
+    def _validate_and_set_deployment_target(
+        self, kwargs: Any, min_target: ct.target, quantization_type: str
+    ) -> None:
+        """Validate or default the minimum deployment target for quantization recipes"""
+        minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
+        if minimum_deployment_target is None:
+            # Default to the minimum target this quantization type requires
+            kwargs["minimum_deployment_target"] = min_target
+        elif minimum_deployment_target < min_target:
+            raise ValueError(
+                f"minimum_deployment_target must be {str(min_target)} or higher for {quantization_type} quantization"
+            )
+
+    def _build_fp_recipe(
         self,
         recipe_type: RecipeType,
         precision: ct.precision,
         **kwargs: Any,
     ) -> ExportRecipe:
+        """Build FP32/FP16 recipe"""
         lowering_recipe = self._get_coreml_lowering_recipe(
             compute_precision=precision,
             **kwargs,
         )
 
         return ExportRecipe(
             name=recipe_type.value,
-            quantization_recipe=None, # TODO - add quantization recipe
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_pt2e_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        activation_dtype: torch.dtype,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build PT2E-based quantization recipe"""
+        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
+
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS17, "pt2e")
+
+        # Validate activation_dtype
+        assert activation_dtype in [
+            torch.quint8,
+            torch.float32,
+        ], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}"
+
+        # Create quantization config
+        config = ct.optimize.torch.quantization.LinearQuantizerConfig(
+            global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
+                quantization_scheme="symmetric",
+                activation_dtype=activation_dtype,
+                weight_dtype=torch.qint8,
+                weight_per_channel=True,
+            )
+        )
+
+        quantizer = CoreMLQuantizer(config)
+        quantization_recipe = QuantizationRecipe(quantizers=[quantizer])
+
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_torchao_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        weight_dtype: torch.dtype,
+        is_per_channel: bool,
+        group_size: int = 32,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build TorchAO-based quantization recipe"""
+        if is_per_channel:
+            weight_granularity = PerAxis(axis=0)
+        else:
+            weight_granularity = PerGroup(group_size=group_size)
+
+        # Use the user-provided filter_fn, if any
+        filter_fn = kwargs.get("filter_fn", None)
+        config = AOQuantizationConfig(
+            ao_base_config=IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=weight_granularity,
+            ),
+            filter_fn=filter_fn,
+        )
+
+        quantization_recipe = QuantizationRecipe(
+            quantizers=None,
+            ao_quantization_configs=[config],
+        )
+
+        # torchao weight-only quantization requires iOS18+ (GH issue #13122)
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_codebook_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        bits: int,
+        block_size: list,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build codebook/palettization quantization recipe"""
+        from torchao.prototype.quantization.codebook_coreml import (
+            CodebookWeightOnlyConfig,
+        )
+
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "codebook")
+
+        # Get the appropriate dtype (torch.uint1 through torch.uint8)
+        dtype = getattr(torch, f"uint{bits}")
+
+        # Use user-provided filter_fn or default to Linear/Embedding layers
+        filter_fn = kwargs.get(
+            "filter_fn",
+            lambda m, fqn: (
+                isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear)
+            ),
+        )
+
+        config = AOQuantizationConfig(
+            ao_base_config=CodebookWeightOnlyConfig(
+                dtype=dtype,
+                block_size=block_size,
+            ),
+            filter_fn=filter_fn,
+        )
+
+        quantization_recipe = QuantizationRecipe(
+            quantizers=None,
+            ao_quantization_configs=[config],
+        )
+
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
             lowering_recipe=lowering_recipe,
         )
 
     def _get_coreml_lowering_recipe(
         self,
-        compute_precision: ct.precision,
+        compute_precision: ct.precision = ct.precision.FLOAT16,
         **kwargs: Any,
     ) -> LoweringRecipe:
+        """Get CoreML lowering recipe with optional precision"""
         compile_specs = CoreMLBackend.generate_compile_specs(
             compute_precision=compute_precision,
-            **kwargs,
+            compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL),
+            minimum_deployment_target=kwargs.get("minimum_deployment_target", None),
         )
 
         minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
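
For context, a usage sketch of how these recipes are fetched and parameterized end to end. This is illustrative and not part of the diff: it assumes the `ExportRecipe.get_recipe` entry point in `executorch.export` and the `CoreMLRecipeType` import path shown below.

    # Illustrative usage sketch -- entry point and import path are assumptions.
    import coremltools as ct
    from executorch.export import ExportRecipe
    from executorch.backends.apple.coreml.recipes import CoreMLRecipeType  # path assumed

    # int4 weight-only, one scale per group of 64 weights; needs iOS18 or newer.
    recipe = ExportRecipe.get_recipe(
        CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
        group_size=64,
        minimum_deployment_target=ct.target.iOS18,
    )

    # Unknown kwargs fail fast in _validate_recipe_kwargs: 'bits' is only
    # accepted by the codebook recipe, so this raises ValueError.
    # ExportRecipe.get_recipe(CoreMLRecipeType.FP16, bits=4)

A deployment target at or above the per-recipe floor (iOS17 for PT2E, iOS18 for torchao and codebook) is passed through unchanged; anything below it raises.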