6 | 6 | from typing import Any, Optional, Sequence
7 | 7 |
8 | 8 | import coremltools as ct
9 | | -import torch
10 | 9 |
11 | 10 | from executorch.backends.apple.coreml.compiler import CoreMLBackend
12 | 11 | from executorch.backends.apple.coreml.partition.coreml_partitioner import (
19 | 18 |
20 | 19 | from executorch.exir import EdgeCompileConfig
21 | 20 | from executorch.export import (
22 | | -    AOQuantizationConfig,
23 | 21 |     BackendRecipeProvider,
24 | 22 |     ExportRecipe,
25 | 23 |     LoweringRecipe,
26 | | -    QuantizationRecipe,
27 | 24 |     RecipeType,
28 | 25 | )
29 | | -from torchao.quantization.granularity import PerAxis, PerGroup
30 | | -from torchao.quantization.quant_api import IntxWeightOnlyConfig
31 | 26 |
32 | 27 |
33 | 28 | class CoreMLRecipeProvider(BackendRecipeProvider):
@@ -55,321 +50,66 @@ def create_recipe(
55 | 50 |         # Validate kwargs
56 | 51 |         self._validate_recipe_kwargs(recipe_type, **kwargs)
57 | 52 |
| 53 | +        # Parse recipe type to get precision and compute unit
| 54 | +        precision = None
58 | 55 |         if recipe_type == CoreMLRecipeType.FP32:
59 | | -            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs)
| 56 | +            precision = ct.precision.FLOAT32
60 | 57 |         elif recipe_type == CoreMLRecipeType.FP16:
61 | | -            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs)
62 | | -        elif recipe_type == CoreMLRecipeType.PT2E_INT8_STATIC:
63 | | -            return self._build_pt2e_quantized_recipe(
64 | | -                recipe_type, activation_dtype=torch.quint8, **kwargs
65 | | -            )
66 | | -        elif recipe_type == CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY:
67 | | -            return self._build_pt2e_quantized_recipe(
68 | | -                recipe_type, activation_dtype=torch.float32, **kwargs
69 | | -            )
70 | | -        elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL:
71 | | -            return self._build_torchao_quantized_recipe(
72 | | -                recipe_type,
73 | | -                weight_dtype=torch.int4,
74 | | -                is_per_channel=True,
75 | | -                **kwargs,
76 | | -            )
77 | | -        elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP:
78 | | -            group_size = kwargs.pop("group_size", 32)
79 | | -            return self._build_torchao_quantized_recipe(
80 | | -                recipe_type,
81 | | -                weight_dtype=torch.int4,
82 | | -                is_per_channel=False,
83 | | -                group_size=group_size,
84 | | -                **kwargs,
85 | | -            )
86 | | -        elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL:
87 | | -            return self._build_torchao_quantized_recipe(
88 | | -                recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
89 | | -            )
90 | | -        elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP:
91 | | -            group_size = kwargs.pop("group_size", 32)
92 | | -            return self._build_torchao_quantized_recipe(
93 | | -                recipe_type,
94 | | -                weight_dtype=torch.int8,
95 | | -                is_per_channel=False,
96 | | -                group_size=group_size,
97 | | -                **kwargs,
98 | | -            )
99 | | -        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
100 | | -            bits = kwargs.pop("bits")
101 | | -            block_size = kwargs.pop("block_size")
102 | | -            return self._build_codebook_quantized_recipe(
103 | | -                recipe_type, bits=bits, block_size=block_size, **kwargs
104 | | -            )
| 58 | +            precision = ct.precision.FLOAT16
105 | 59 |
106 | | -        return None
| 60 | +        if precision is None:
| 61 | +            raise ValueError(f"Unknown precision for recipe: {recipe_type.value}")
107 | 62 |
108 | | -    def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
109 | | -        """Validate kwargs for each recipe type"""
110 | | -        expected_keys = self._get_expected_keys(recipe_type)
| 63 | +        return self._build_recipe(recipe_type, precision, **kwargs)
111 | 64 |
| 65 | +    def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
| 66 | +        if not kwargs:
| 67 | +            return
| 68 | +        expected_keys = {"minimum_deployment_target", "compute_unit"}
112 | 69 |         unexpected = set(kwargs.keys()) - expected_keys
113 | 70 |         if unexpected:
114 | 71 |             raise ValueError(
115 | | -                f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
| 72 | +                f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. "
| 73 | +                f"Unexpected parameters: {list(unexpected)}"
116 | 74 |             )
117 | | -
118 | | -        self._validate_base_parameters(kwargs)
119 | | -        self._validate_group_size_parameter(recipe_type, kwargs)
120 | | -        self._validate_codebook_parameters(recipe_type, kwargs)
121 | | -
122 | | -    def _get_expected_keys(self, recipe_type: RecipeType) -> set:
123 | | -        """Get expected parameter keys for a recipe type"""
124 | | -        common_keys = {"minimum_deployment_target", "compute_unit"}
125 | | -
126 | | -        if recipe_type in [
127 | | -            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
128 | | -            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
129 | | -        ]:
130 | | -            return common_keys | {"group_size", "filter_fn"}
131 | | -        elif recipe_type in [
132 | | -            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL,
133 | | -            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
134 | | -        ]:
135 | | -            return common_keys | {"filter_fn"}
136 | | -        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
137 | | -            return common_keys | {"bits", "block_size", "filter_fn"}
138 | | -        else:
139 | | -            return common_keys
140 | | -
141 | | -    def _validate_base_parameters(self, kwargs: Any) -> None:
142 | | -        """Validate minimum_deployment_target and compute_unit parameters"""
143 | 75 |         if "minimum_deployment_target" in kwargs:
144 | 76 |             minimum_deployment_target = kwargs["minimum_deployment_target"]
145 | 77 |             if not isinstance(minimum_deployment_target, ct.target):
146 | 78 |                 raise ValueError(
147 | 79 |                     f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}"
148 | 80 |                 )
149 | | -
150 | 81 |         if "compute_unit" in kwargs:
151 | 82 |             compute_unit = kwargs["compute_unit"]
152 | 83 |             if not isinstance(compute_unit, ct.ComputeUnit):
153 | 84 |                 raise ValueError(
154 | 85 |                     f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}"
155 | 86 |                 )
156 | 87 |
157 | | -    def _validate_group_size_parameter(
158 | | -        self, recipe_type: RecipeType, kwargs: Any
159 | | -    ) -> None:
160 | | -        """Validate group_size parameter for applicable recipe types"""
161 | | -        if (
162 | | -            recipe_type
163 | | -            in [
164 | | -                CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
165 | | -                CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
166 | | -            ]
167 | | -            and "group_size" in kwargs
168 | | -        ):
169 | | -            group_size = kwargs["group_size"]
170 | | -            if not isinstance(group_size, int):
171 | | -                raise ValueError(
172 | | -                    f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
173 | | -                )
174 | | -            if group_size <= 0:
175 | | -                raise ValueError(
176 | | -                    f"Parameter 'group_size' must be positive, got: {group_size}"
177 | | -                )
178 | | -
179 | | -    def _validate_codebook_parameters(
180 | | -        self, recipe_type: RecipeType, kwargs: Any
181 | | -    ) -> None:
182 | | -        """Validate bits and block_size parameters for codebook recipe type"""
183 | | -        if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
184 | | -            return
185 | | -
186 | | -        # Both bits and block_size must be present
187 | | -        if not ("bits" in kwargs and "block_size" in kwargs):
188 | | -            raise ValueError(
189 | | -                "Parameters 'bits' and 'block_size' must be present for codebook recipes"
190 | | -            )
191 | | -
192 | | -        if "bits" in kwargs:
193 | | -            bits = kwargs["bits"]
194 | | -            if not isinstance(bits, int):
195 | | -                raise ValueError(
196 | | -                    f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}"
197 | | -                )
198 | | -            if not (1 <= bits <= 8):
199 | | -                raise ValueError(
200 | | -                    f"Parameter 'bits' must be between 1 and 8, got: {bits}"
201 | | -                )
202 | | -
203 | | -        if "block_size" in kwargs:
204 | | -            block_size = kwargs["block_size"]
205 | | -            if not isinstance(block_size, list):
206 | | -                raise ValueError(
207 | | -                    f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}"
208 | | -                )
209 | | -
210 | | -    def _validate_and_set_deployment_target(
211 | | -        self, kwargs: Any, min_target: ct.target, quantization_type: str
212 | | -    ) -> None:
213 | | -        """Validate or set minimum deployment target for quantization recipes"""
214 | | -        minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
215 | | -        if minimum_deployment_target and minimum_deployment_target < min_target:
216 | | -            raise ValueError(
217 | | -                f"minimum_deployment_target must be {str(min_target)} or higher for {quantization_type} quantization"
218 | | -            )
219 | | -        else:
220 | | -            # Default to the minimum target for this quantization type
221 | | -            kwargs["minimum_deployment_target"] = min_target
222 | | -
223 | | -    def _build_fp_recipe(
| 88 | +    def _build_recipe(
224 | 89 |         self,
225 | 90 |         recipe_type: RecipeType,
226 | 91 |         precision: ct.precision,
227 | 92 |         **kwargs: Any,
228 | 93 |     ) -> ExportRecipe:
229 | | -        """Build FP32/FP16 recipe"""
230 | 94 |         lowering_recipe = self._get_coreml_lowering_recipe(
231 | 95 |             compute_precision=precision,
232 | 96 |             **kwargs,
233 | 97 |         )
234 | 98 |
235 | 99 |         return ExportRecipe(
236 | 100 |             name=recipe_type.value,
237 | | -            lowering_recipe=lowering_recipe,
238 | | -        )
239 | | -
240 | | -    def _build_pt2e_quantized_recipe(
241 | | -        self,
242 | | -        recipe_type: RecipeType,
243 | | -        activation_dtype: torch.dtype,
244 | | -        **kwargs: Any,
245 | | -    ) -> ExportRecipe:
246 | | -        """Build PT2E-based quantization recipe"""
247 | | -        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
248 | | -
249 | | -        self._validate_and_set_deployment_target(kwargs, ct.target.iOS17, "pt2e")
250 | | -
251 | | -        # Validate activation_dtype
252 | | -        assert activation_dtype in [
253 | | -            torch.quint8,
254 | | -            torch.float32,
255 | | -        ], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}"
256 | | -
257 | | -        # Create quantization config
258 | | -        config = ct.optimize.torch.quantization.LinearQuantizerConfig(
259 | | -            global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
260 | | -                quantization_scheme="symmetric",
261 | | -                activation_dtype=activation_dtype,
262 | | -                weight_dtype=torch.qint8,
263 | | -                weight_per_channel=True,
264 | | -            )
265 | | -        )
266 | | -
267 | | -        quantizer = CoreMLQuantizer(config)
268 | | -        quantization_recipe = QuantizationRecipe(quantizers=[quantizer])
269 | | -
270 | | -        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
271 | | -
272 | | -        return ExportRecipe(
273 | | -            name=recipe_type.value,
274 | | -            quantization_recipe=quantization_recipe,
275 | | -            lowering_recipe=lowering_recipe,
276 | | -        )
277 | | -
278 | | -    def _build_torchao_quantized_recipe(
279 | | -        self,
280 | | -        recipe_type: RecipeType,
281 | | -        weight_dtype: torch.dtype,
282 | | -        is_per_channel: bool,
283 | | -        group_size: int = 32,
284 | | -        **kwargs: Any,
285 | | -    ) -> ExportRecipe:
286 | | -        """Build TorchAO-based quantization recipe"""
287 | | -        if is_per_channel:
288 | | -            weight_granularity = PerAxis(axis=0)
289 | | -        else:
290 | | -            weight_granularity = PerGroup(group_size=group_size)
291 | | -
292 | | -        # Use user-provided filter_fn if provided
293 | | -        filter_fn = kwargs.get("filter_fn", None)
294 | | -        config = AOQuantizationConfig(
295 | | -            ao_base_config=IntxWeightOnlyConfig(
296 | | -                weight_dtype=weight_dtype,
297 | | -                granularity=weight_granularity,
298 | | -            ),
299 | | -            filter_fn=filter_fn,
300 | | -        )
301 | | -
302 | | -        quantization_recipe = QuantizationRecipe(
303 | | -            quantizers=None,
304 | | -            ao_quantization_configs=[config],
305 | | -        )
306 | | -
307 | | -        # override minimum_deployment_target to ios18 for torchao (GH issue #13122)
308 | | -        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
309 | | -        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
310 | | -
311 | | -        return ExportRecipe(
312 | | -            name=recipe_type.value,
313 | | -            quantization_recipe=quantization_recipe,
314 | | -            lowering_recipe=lowering_recipe,
315 | | -        )
316 | | -
317 | | -    def _build_codebook_quantized_recipe(
318 | | -        self,
319 | | -        recipe_type: RecipeType,
320 | | -        bits: int,
321 | | -        block_size: list,
322 | | -        **kwargs: Any,
323 | | -    ) -> ExportRecipe:
324 | | -        """Build codebook/palettization quantization recipe"""
325 | | -        from torchao.prototype.quantization.codebook_coreml import (
326 | | -            CodebookWeightOnlyConfig,
327 | | -        )
328 | | -
329 | | -        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "codebook")
330 | | -
331 | | -        # Get the appropriate dtype (torch.uint1 through torch.uint8)
332 | | -        dtype = getattr(torch, f"uint{bits}")
333 | | -
334 | | -        # Use user-provided filter_fn or default to Linear/Embedding layers
335 | | -        filter_fn = kwargs.get(
336 | | -            "filter_fn",
337 | | -            lambda m, fqn: (
338 | | -                isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear)
339 | | -            ),
340 | | -        )
341 | | -
342 | | -        config = AOQuantizationConfig(
343 | | -            ao_base_config=CodebookWeightOnlyConfig(
344 | | -                dtype=dtype,
345 | | -                block_size=block_size,
346 | | -            ),
347 | | -            filter_fn=filter_fn,
348 | | -        )
349 | | -
350 | | -        quantization_recipe = QuantizationRecipe(
351 | | -            quantizers=None,
352 | | -            ao_quantization_configs=[config],
353 | | -        )
354 | | -
355 | | -        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
356 | | -
357 | | -        return ExportRecipe(
358 | | -            name=recipe_type.value,
359 | | -            quantization_recipe=quantization_recipe,
| 101 | +            quantization_recipe=None,  # TODO - add quantization recipe
360 | 102 |             lowering_recipe=lowering_recipe,
361 | 103 |         )
362 | 104 |
363 | 105 |     def _get_coreml_lowering_recipe(
364 | 106 |         self,
365 | | -        compute_precision: ct.precision = ct.precision.FLOAT16,
| 107 | +        compute_precision: ct.precision,
366 | 108 |         **kwargs: Any,
367 | 109 |     ) -> LoweringRecipe:
368 | | -        """Get CoreML lowering recipe with optional precision"""
369 | 110 |         compile_specs = CoreMLBackend.generate_compile_specs(
370 | 111 |             compute_precision=compute_precision,
371 | | -            compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL),
372 | | -            minimum_deployment_target=kwargs.get("minimum_deployment_target", None),
| 112 | +            **kwargs,
373 | 113 |         )
374 | 114 |
375 | 115 |         minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
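
Usage note: a minimal sketch of how the simplified provider might be exercised after this change. The import paths and the no-argument constructor are assumptions (the diff edits the provider file but does not show its package location or where CoreMLRecipeType is defined); only the FP32/FP16 recipes remain, and only the minimum_deployment_target and compute_unit kwargs pass validation.

import coremltools as ct

# Assumed module paths -- adjust to wherever CoreMLRecipeProvider and
# CoreMLRecipeType actually live in the ExecuTorch tree.
from executorch.backends.apple.coreml.recipes.coreml_recipe_provider import (
    CoreMLRecipeProvider,
)
from executorch.backends.apple.coreml.recipes.coreml_recipe_types import (
    CoreMLRecipeType,
)

# Assumes the provider needs no constructor arguments.
provider = CoreMLRecipeProvider()

# FP16 recipe; these two kwargs are the only ones _validate_recipe_kwargs
# still accepts, and both remain optional.
recipe = provider.create_recipe(
    CoreMLRecipeType.FP16,
    minimum_deployment_target=ct.target.iOS17,
    compute_unit=ct.ComputeUnit.ALL,
)

# Any other keyword (e.g. the old group_size / bits / block_size knobs)
# now raises ValueError instead of selecting a quantized recipe.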