
Commit 0ad3df9

[executorch] Add coreml quant recipes (#13441)
Co-authored-by: Abhinay Kukkadapu <[email protected]>
1 parent: 5b96542

File tree

4 files changed: +801 additions, -175 deletions
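
The recipes added here are obtained through the CoreML backend's recipe provider. A minimal usage sketch follows; the import path is an assumption based on this commit's file layout, while `CoreMLRecipeProvider.create_recipe` and the recipe-type names appear verbatim in the diff below.

```python
# Minimal sketch, assuming the package layout implied by this commit;
# create_recipe() and the recipe-type names are taken from the diff below.
from executorch.backends.apple.coreml.recipes import (  # assumed import path
    CoreMLRecipeProvider,
    CoreMLRecipeType,
)

provider = CoreMLRecipeProvider()

# Plain fp16 lowering, no quantization.
fp16_recipe = provider.create_recipe(CoreMLRecipeType.FP16)

# int4 weight-only quantization with one scale per group of 32 weights.
int4_recipe = provider.create_recipe(
    CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, group_size=32
)
```

Each call returns an `ExportRecipe` that bundles an optional quantization recipe with a CoreML lowering recipe.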

backends/apple/coreml/TARGETS

Lines changed: 2 additions & 0 deletions

@@ -120,11 +120,13 @@ runtime.python_test(
         "test/*.py",
     ]),
     deps = [
+        "fbsource//third-party/pypi/coremltools:coremltools",
         "fbsource//third-party/pypi/pytest:pytest",
         ":partitioner",
         ":quantizer",
         ":recipes",
         "//caffe2:torch",
         "//pytorch/vision:torchvision",
+        "fbsource//third-party/pypi/scikit-learn:scikit-learn",
     ],
 )

backends/apple/coreml/recipes/coreml_recipe_provider.py

Lines changed: 277 additions & 17 deletions
@@ -6,6 +6,7 @@
 from typing import Any, Optional, Sequence
 
 import coremltools as ct
+import torch
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition.coreml_partitioner import (
@@ -18,11 +19,15 @@
 
 from executorch.exir import EdgeCompileConfig
 from executorch.export import (
+    AOQuantizationConfig,
     BackendRecipeProvider,
     ExportRecipe,
     LoweringRecipe,
+    QuantizationRecipe,
     RecipeType,
 )
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig
 
 
 class CoreMLRecipeProvider(BackendRecipeProvider):
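
The new torchao imports are the weight-only quantization primitives the recipe builders below construct. For reference, here is a standalone sketch of the two granularity modes, mirroring the exact combinations `_build_torchao_quantized_recipe` uses later in this diff:

```python
# Sketch of the torchao configs used by the new recipes; the argument
# combinations below mirror _build_torchao_quantized_recipe in this diff.
import torch
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig

# Per-channel: one scale per output channel (axis 0).
per_channel_cfg = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(axis=0),
)

# Per-group: one scale per group of 32 consecutive weights within a channel.
per_group_cfg = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(group_size=32),
)
```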
@@ -50,66 +55,321 @@ def create_recipe(
         # Validate kwargs
         self._validate_recipe_kwargs(recipe_type, **kwargs)
 
-        # Parse recipe type to get precision and compute unit
-        precision = None
         if recipe_type == CoreMLRecipeType.FP32:
-            precision = ct.precision.FLOAT32
+            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs)
         elif recipe_type == CoreMLRecipeType.FP16:
-            precision = ct.precision.FLOAT16
-
-        if precision is None:
-            raise ValueError(f"Unknown precision for recipe: {recipe_type.value}")
+            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs)
+        elif recipe_type == CoreMLRecipeType.PT2E_INT8_STATIC:
+            return self._build_pt2e_quantized_recipe(
+                recipe_type, activation_dtype=torch.quint8, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY:
+            return self._build_pt2e_quantized_recipe(
+                recipe_type, activation_dtype=torch.float32, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL:
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int4,
+                is_per_channel=True,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP:
+            group_size = kwargs.pop("group_size", 32)
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int4,
+                is_per_channel=False,
+                group_size=group_size,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL:
+            return self._build_torchao_quantized_recipe(
+                recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP:
+            group_size = kwargs.pop("group_size", 32)
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int8,
+                is_per_channel=False,
+                group_size=group_size,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            bits = kwargs.pop("bits")
+            block_size = kwargs.pop("block_size")
+            return self._build_codebook_quantized_recipe(
+                recipe_type, bits=bits, block_size=block_size, **kwargs
+            )
 
-        return self._build_recipe(recipe_type, precision, **kwargs)
+        return None
 
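Each branch maps one `CoreMLRecipeType` to a dedicated builder; unknown types fall through to `return None`, signalling that this provider does not handle the requested recipe. Note that the codebook branch pops `bits` and `block_size` with no default, relying on `_validate_recipe_kwargs` (continued below) to have enforced their presence. A small sketch of that failure mode, reusing the provider from the sketch above:

```python
# Sketch: codebook recipes require both 'bits' and 'block_size'; validation
# runs before the dispatch above, so the pop() calls never miss.
try:
    provider.create_recipe(CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY, bits=4)
except ValueError as e:
    print(e)  # Parameters 'bits' and 'block_size' must be present for codebook recipes
```

The hunk continues with those validation helpers.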
     def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
-        if not kwargs:
-            return
-        expected_keys = {"minimum_deployment_target", "compute_unit"}
+        """Validate kwargs for each recipe type"""
+        expected_keys = self._get_expected_keys(recipe_type)
+
         unexpected = set(kwargs.keys()) - expected_keys
         if unexpected:
             raise ValueError(
-                f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. "
-                f"Unexpected parameters: {list(unexpected)}"
+                f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
             )
+
+        self._validate_base_parameters(kwargs)
+        self._validate_group_size_parameter(recipe_type, kwargs)
+        self._validate_codebook_parameters(recipe_type, kwargs)
+
+    def _get_expected_keys(self, recipe_type: RecipeType) -> set:
+        """Get expected parameter keys for a recipe type"""
+        common_keys = {"minimum_deployment_target", "compute_unit"}
+
+        if recipe_type in [
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
+        ]:
+            return common_keys | {"group_size", "filter_fn"}
+        elif recipe_type in [
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
+        ]:
+            return common_keys | {"filter_fn"}
+        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            return common_keys | {"bits", "block_size", "filter_fn"}
+        else:
+            return common_keys
+
+    def _validate_base_parameters(self, kwargs: Any) -> None:
+        """Validate minimum_deployment_target and compute_unit parameters"""
         if "minimum_deployment_target" in kwargs:
             minimum_deployment_target = kwargs["minimum_deployment_target"]
             if not isinstance(minimum_deployment_target, ct.target):
                 raise ValueError(
                     f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}"
                 )
+
         if "compute_unit" in kwargs:
             compute_unit = kwargs["compute_unit"]
             if not isinstance(compute_unit, ct.ComputeUnit):
                 raise ValueError(
                     f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}"
                 )
 
-    def _build_recipe(
+    def _validate_group_size_parameter(
+        self, recipe_type: RecipeType, kwargs: Any
+    ) -> None:
+        """Validate group_size parameter for applicable recipe types"""
+        if (
+            recipe_type
+            in [
+                CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
+                CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
+            ]
+            and "group_size" in kwargs
+        ):
+            group_size = kwargs["group_size"]
+            if not isinstance(group_size, int):
+                raise ValueError(
+                    f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
+                )
+            if group_size <= 0:
+                raise ValueError(
+                    f"Parameter 'group_size' must be positive, got: {group_size}"
+                )
+
+    def _validate_codebook_parameters(
+        self, recipe_type: RecipeType, kwargs: Any
+    ) -> None:
+        """Validate bits and block_size parameters for codebook recipe type"""
+        if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            return
+
+        # Both bits and block_size must be present
+        if not ("bits" in kwargs and "block_size" in kwargs):
+            raise ValueError(
+                "Parameters 'bits' and 'block_size' must be present for codebook recipes"
+            )
+
+        if "bits" in kwargs:
+            bits = kwargs["bits"]
+            if not isinstance(bits, int):
+                raise ValueError(
+                    f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}"
+                )
+            if not (1 <= bits <= 8):
+                raise ValueError(
+                    f"Parameter 'bits' must be between 1 and 8, got: {bits}"
+                )
+
+        if "block_size" in kwargs:
+            block_size = kwargs["block_size"]
+            if not isinstance(block_size, list):
+                raise ValueError(
+                    f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}"
+                )
+
+    def _validate_and_set_deployment_target(
+        self, kwargs: Any, min_target: ct.target, quantization_type: str
+    ) -> None:
+        """Validate or set minimum deployment target for quantization recipes"""
+        minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
+        if minimum_deployment_target and minimum_deployment_target < min_target:
+            raise ValueError(
+                f"minimum_deployment_target must be {str(min_target)} or higher for {quantization_type} quantization"
+            )
+        else:
+            # Default to the minimum target for this quantization type
+            kwargs["minimum_deployment_target"] = min_target
+
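Together these helpers give each recipe family its own parameter schema: the per-group recipes accept `group_size`, all TorchAO and codebook recipes accept `filter_fn`, and every recipe accepts `minimum_deployment_target` and `compute_unit`. A quick sketch of the resulting behavior, with the provider from the earlier sketch:

```python
# Sketch: malformed kwargs are rejected before any recipe is built.
try:
    provider.create_recipe(
        CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, group_size=-8
    )
except ValueError as e:
    print(e)  # Parameter 'group_size' must be positive, got: -8

try:
    provider.create_recipe(CoreMLRecipeType.FP32, group_size=32)
except ValueError as e:
    print(e)  # ... received unexpected parameters: ['group_size']
```

The recipe builders follow.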
+    def _build_fp_recipe(
         self,
         recipe_type: RecipeType,
         precision: ct.precision,
         **kwargs: Any,
     ) -> ExportRecipe:
+        """Build FP32/FP16 recipe"""
         lowering_recipe = self._get_coreml_lowering_recipe(
             compute_precision=precision,
             **kwargs,
         )
 
         return ExportRecipe(
             name=recipe_type.value,
-            quantization_recipe=None,  # TODO - add quantization recipe
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_pt2e_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        activation_dtype: torch.dtype,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build PT2E-based quantization recipe"""
+        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
+
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS17, "pt2e")
+
+        # Validate activation_dtype
+        assert activation_dtype in [
+            torch.quint8,
+            torch.float32,
+        ], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}"
+
+        # Create quantization config
+        config = ct.optimize.torch.quantization.LinearQuantizerConfig(
+            global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
+                quantization_scheme="symmetric",
+                activation_dtype=activation_dtype,
+                weight_dtype=torch.qint8,
+                weight_per_channel=True,
+            )
+        )
+
+        quantizer = CoreMLQuantizer(config)
+        quantization_recipe = QuantizationRecipe(quantizers=[quantizer])
+
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
+            lowering_recipe=lowering_recipe,
+        )
+
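Both PT2E recipes share this builder: `PT2E_INT8_STATIC` passes `activation_dtype=torch.quint8` (weights and activations quantized), while `PT2E_INT8_WEIGHT_ONLY` passes `torch.float32` (weights only). Both default the deployment target to iOS17, and an explicitly lower target raises. A sketch, continuing the provider example:

```python
import coremltools as ct

# PT2E static: symmetric int8 weights plus quint8 activations; iOS17 is
# both the default and the floor for minimum_deployment_target.
static_recipe = provider.create_recipe(
    CoreMLRecipeType.PT2E_INT8_STATIC,
    minimum_deployment_target=ct.target.iOS17,
)

# PT2E weight-only: activations stay in float32.
weight_only_recipe = provider.create_recipe(CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY)
```

The TorchAO and codebook builders follow.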
+    def _build_torchao_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        weight_dtype: torch.dtype,
+        is_per_channel: bool,
+        group_size: int = 32,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build TorchAO-based quantization recipe"""
+        if is_per_channel:
+            weight_granularity = PerAxis(axis=0)
+        else:
+            weight_granularity = PerGroup(group_size=group_size)
+
+        # Use user-provided filter_fn if provided
+        filter_fn = kwargs.get("filter_fn", None)
+        config = AOQuantizationConfig(
+            ao_base_config=IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=weight_granularity,
+            ),
+            filter_fn=filter_fn,
+        )
+
+        quantization_recipe = QuantizationRecipe(
+            quantizers=None,
+            ao_quantization_configs=[config],
+        )
+
+        # override minimum_deployment_target to ios18 for torchao (GH issue #13122)
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_codebook_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        bits: int,
+        block_size: list,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build codebook/palettization quantization recipe"""
+        from torchao.prototype.quantization.codebook_coreml import (
+            CodebookWeightOnlyConfig,
+        )
+
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "codebook")
+
+        # Get the appropriate dtype (torch.uint1 through torch.uint8)
+        dtype = getattr(torch, f"uint{bits}")
+
+        # Use user-provided filter_fn or default to Linear/Embedding layers
+        filter_fn = kwargs.get(
+            "filter_fn",
+            lambda m, fqn: (
+                isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear)
+            ),
+        )
+
+        config = AOQuantizationConfig(
+            ao_base_config=CodebookWeightOnlyConfig(
+                dtype=dtype,
+                block_size=block_size,
+            ),
+            filter_fn=filter_fn,
+        )
+
+        quantization_recipe = QuantizationRecipe(
+            quantizers=None,
+            ao_quantization_configs=[config],
+        )
+
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
             lowering_recipe=lowering_recipe,
         )
 
     def _get_coreml_lowering_recipe(
         self,
-        compute_precision: ct.precision,
+        compute_precision: ct.precision = ct.precision.FLOAT16,
         **kwargs: Any,
     ) -> LoweringRecipe:
+        """Get CoreML lowering recipe with optional precision"""
         compile_specs = CoreMLBackend.generate_compile_specs(
             compute_precision=compute_precision,
-            **kwargs,
+            compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL),
+            minimum_deployment_target=kwargs.get("minimum_deployment_target", None),
         )
 
         minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
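
Two details of the codebook builder are worth noting: `bits` selects a sub-byte dtype via `getattr(torch, f"uint{bits}")` (so `bits=3` yields `torch.uint3`), and the default `filter_fn` palettizes only `nn.Linear` and `nn.Embedding` modules. A final sketch, with the provider as before (the `block_size` value is illustrative):

```python
import torch

# Restrict palettization to embeddings only via a custom filter_fn.
embeddings_only = lambda m, fqn: isinstance(m, torch.nn.Embedding)

codebook_recipe = provider.create_recipe(
    CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY,
    bits=3,               # any integer in [1, 8]; maps to torch.uint3
    block_size=[-1, 16],  # illustrative value; must be a list
    filter_fn=embeddings_only,
)
```

The lowering helper at the end of the hunk now defaults `compute_precision` to FLOAT16 and forwards only the known `compute_unit` and `minimum_deployment_target` kwargs to `CoreMLBackend.generate_compile_specs`, rather than splatting all kwargs through.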
