Skip to content

Commit c0c32b6

Browse files
Add coreml quant recipes
ghstack-source-id: 76f6fc8 ghstack-comment-id: 3172341606 Pull-Request: #13265
1 parent aaafcbd commit c0c32b6

File tree

5 files changed

+749
-183
lines changed

5 files changed

+749
-183
lines changed

backends/apple/coreml/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ runtime.python_test(
120120
"test/*.py",
121121
]),
122122
deps = [
123+
"fbsource//third-party/pypi/coremltools:coremltools",
123124
"fbsource//third-party/pypi/pytest:pytest",
124125
":partitioner",
125126
":quantizer",

backends/apple/coreml/recipes/coreml_recipe_provider.py

Lines changed: 270 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Optional, Sequence
77

88
import coremltools as ct
9+
import torch
910

1011
from executorch.backends.apple.coreml.compiler import CoreMLBackend
1112
from executorch.backends.apple.coreml.partition.coreml_partitioner import (
@@ -18,11 +19,15 @@
1819

1920
from executorch.exir import EdgeCompileConfig
2021
from executorch.export import (
22+
AOQuantizationConfig,
2123
BackendRecipeProvider,
2224
ExportRecipe,
2325
LoweringRecipe,
26+
QuantizationRecipe,
2427
RecipeType,
2528
)
29+
from torchao.quantization.granularity import PerAxis, PerGroup
30+
from torchao.quantization.quant_api import IntxWeightOnlyConfig
2631

2732

2833
class CoreMLRecipeProvider(BackendRecipeProvider):
@@ -50,66 +55,314 @@ def create_recipe(
5055
# Validate kwargs
5156
self._validate_recipe_kwargs(recipe_type, **kwargs)
5257

53-
# Parse recipe type to get precision and compute unit
54-
precision = None
5558
if recipe_type == CoreMLRecipeType.FP32:
56-
precision = ct.precision.FLOAT32
59+
return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs)
5760
elif recipe_type == CoreMLRecipeType.FP16:
58-
precision = ct.precision.FLOAT16
59-
60-
if precision is None:
61-
raise ValueError(f"Unknown precision for recipe: {recipe_type.value}")
61+
return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs)
62+
elif recipe_type == CoreMLRecipeType.INT8_STATIC:
63+
return self._build_pt2e_quantized_recipe(
64+
recipe_type, activation_dtype=torch.quint8, **kwargs
65+
)
66+
elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY:
67+
return self._build_pt2e_quantized_recipe(
68+
recipe_type, activation_dtype=torch.float32, **kwargs
69+
)
70+
elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL:
71+
return self._build_torchao_quantized_recipe(
72+
recipe_type,
73+
weight_dtype=torch.int4,
74+
is_per_channel=True,
75+
**kwargs,
76+
)
77+
elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP:
78+
group_size = kwargs.pop("group_size", 32)
79+
return self._build_torchao_quantized_recipe(
80+
recipe_type,
81+
weight_dtype=torch.int4,
82+
is_per_channel=False,
83+
group_size=group_size,
84+
**kwargs,
85+
)
86+
elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL:
87+
return self._build_torchao_quantized_recipe(
88+
recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
89+
)
90+
elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP:
91+
group_size = kwargs.pop("group_size", 32)
92+
return self._build_torchao_quantized_recipe(
93+
recipe_type,
94+
weight_dtype=torch.int8,
95+
is_per_channel=False,
96+
group_size=group_size,
97+
**kwargs,
98+
)
99+
elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
100+
bits = kwargs.pop("bits", 3)
101+
block_size = kwargs.pop("block_size", [-1, 16])
102+
return self._build_codebook_quantized_recipe(
103+
recipe_type, bits=bits, block_size=block_size, **kwargs
104+
)
62105

63-
return self._build_recipe(recipe_type, precision, **kwargs)
106+
return None
64107

65108
def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
66-
if not kwargs:
67-
return
68-
expected_keys = {"minimum_deployment_target", "compute_unit"}
109+
"""Validate kwargs for each recipe type"""
110+
expected_keys = self._get_expected_keys(recipe_type)
111+
69112
unexpected = set(kwargs.keys()) - expected_keys
70113
if unexpected:
71114
raise ValueError(
72-
f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. "
73-
f"Unexpected parameters: {list(unexpected)}"
115+
f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
74116
)
117+
118+
self._validate_base_parameters(kwargs)
119+
self._validate_group_size_parameter(recipe_type, kwargs)
120+
self._validate_codebook_parameters(recipe_type, kwargs)
121+
122+
def _get_expected_keys(self, recipe_type: RecipeType) -> set:
123+
"""Get expected parameter keys for a recipe type"""
124+
common_keys = {"minimum_deployment_target", "compute_unit"}
125+
126+
if recipe_type in [
127+
CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
128+
CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
129+
]:
130+
return common_keys | {"group_size", "filter_fn"}
131+
elif recipe_type in [
132+
CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL,
133+
CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL,
134+
]:
135+
return common_keys | {"filter_fn"}
136+
elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
137+
return common_keys | {"bits", "block_size", "filter_fn"}
138+
else:
139+
return common_keys
140+
141+
def _validate_base_parameters(self, kwargs: Any) -> None:
142+
"""Validate minimum_deployment_target and compute_unit parameters"""
75143
if "minimum_deployment_target" in kwargs:
76144
minimum_deployment_target = kwargs["minimum_deployment_target"]
77145
if not isinstance(minimum_deployment_target, ct.target):
78146
raise ValueError(
79147
f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}"
80148
)
149+
81150
if "compute_unit" in kwargs:
82151
compute_unit = kwargs["compute_unit"]
83152
if not isinstance(compute_unit, ct.ComputeUnit):
84153
raise ValueError(
85154
f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}"
86155
)
87156

88-
def _build_recipe(
157+
def _validate_group_size_parameter(
158+
self, recipe_type: RecipeType, kwargs: Any
159+
) -> None:
160+
"""Validate group_size parameter for applicable recipe types"""
161+
if (
162+
recipe_type
163+
in [
164+
CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
165+
CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
166+
]
167+
and "group_size" in kwargs
168+
):
169+
group_size = kwargs["group_size"]
170+
if not isinstance(group_size, int):
171+
raise ValueError(
172+
f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
173+
)
174+
if group_size <= 0:
175+
raise ValueError(
176+
f"Parameter 'group_size' must be positive, got: {group_size}"
177+
)
178+
179+
def _validate_codebook_parameters(
180+
self, recipe_type: RecipeType, kwargs: Any
181+
) -> None:
182+
"""Validate bits and block_size parameters for codebook recipe type"""
183+
if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
184+
return
185+
186+
if "bits" in kwargs:
187+
bits = kwargs["bits"]
188+
if not isinstance(bits, int):
189+
raise ValueError(
190+
f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}"
191+
)
192+
if not (1 <= bits <= 8):
193+
raise ValueError(
194+
f"Parameter 'bits' must be between 1 and 8, got: {bits}"
195+
)
196+
197+
if "block_size" in kwargs:
198+
block_size = kwargs["block_size"]
199+
if not isinstance(block_size, list):
200+
raise ValueError(
201+
f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}"
202+
)
203+
204+
def _build_fp_recipe(
89205
self,
90206
recipe_type: RecipeType,
91207
precision: ct.precision,
92208
**kwargs: Any,
93209
) -> ExportRecipe:
210+
"""Build FP32/FP16 recipe"""
94211
lowering_recipe = self._get_coreml_lowering_recipe(
95212
compute_precision=precision,
96213
**kwargs,
97214
)
98215

99216
return ExportRecipe(
100217
name=recipe_type.value,
101-
quantization_recipe=None, # TODO - add quantization recipe
218+
lowering_recipe=lowering_recipe,
219+
)
220+
221+
def _build_pt2e_quantized_recipe(
222+
self,
223+
recipe_type: RecipeType,
224+
activation_dtype: torch.dtype,
225+
**kwargs: Any,
226+
) -> ExportRecipe:
227+
"""Build PT2E-based quantization recipe"""
228+
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
229+
230+
minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
231+
if minimum_deployment_target and minimum_deployment_target < ct.target.iOS17:
232+
raise ValueError(
233+
"minimum_deployment_target must be iOS17 or higher for codebook quantization"
234+
)
235+
# Default to iOS17 for quantization
236+
kwargs["minimum_deployment_target"] = ct.target.iOS17
237+
238+
# Validate activation_dtype
239+
assert activation_dtype in [
240+
torch.quint8,
241+
torch.float32,
242+
], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}"
243+
244+
# Create quantization config
245+
config = ct.optimize.torch.quantization.LinearQuantizerConfig(
246+
global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
247+
quantization_scheme="symmetric",
248+
activation_dtype=activation_dtype,
249+
weight_dtype=torch.qint8,
250+
weight_per_channel=True,
251+
)
252+
)
253+
254+
quantizer = CoreMLQuantizer(config)
255+
quantization_recipe = QuantizationRecipe(quantizers=[quantizer])
256+
257+
lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
258+
259+
return ExportRecipe(
260+
name=recipe_type.value,
261+
quantization_recipe=quantization_recipe,
262+
lowering_recipe=lowering_recipe,
263+
)
264+
265+
def _build_torchao_quantized_recipe(
266+
self,
267+
recipe_type: RecipeType,
268+
weight_dtype: torch.dtype,
269+
is_per_channel: bool,
270+
group_size: int = 32,
271+
**kwargs: Any,
272+
) -> ExportRecipe:
273+
"""Build TorchAO-based quantization recipe"""
274+
if is_per_channel:
275+
weight_granularity = PerAxis(axis=0)
276+
else:
277+
weight_granularity = PerGroup(group_size=group_size)
278+
279+
# Use user-provided filter_fn or default to Linear/Embedding layers
280+
filter_fn = kwargs.get("filter_fn", None)
281+
config = AOQuantizationConfig(
282+
ao_base_config=IntxWeightOnlyConfig(
283+
weight_dtype=weight_dtype,
284+
granularity=weight_granularity,
285+
),
286+
filter_fn=filter_fn,
287+
)
288+
289+
quantization_recipe = QuantizationRecipe(
290+
quantizers=None,
291+
ao_quantization_configs=[config],
292+
)
293+
294+
# override minimum_deployment_target to ios18 for torchao (GH issue #13122)
295+
kwargs["minimum_deployment_target"] = ct.target.iOS18
296+
lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
297+
298+
return ExportRecipe(
299+
name=recipe_type.value,
300+
quantization_recipe=quantization_recipe,
301+
lowering_recipe=lowering_recipe,
302+
)
303+
304+
def _build_codebook_quantized_recipe(
305+
self,
306+
recipe_type: RecipeType,
307+
bits: int,
308+
block_size: list,
309+
**kwargs: Any,
310+
) -> ExportRecipe:
311+
"""Build codebook/palettization quantization recipe"""
312+
from torchao.prototype.quantization.codebook_coreml import (
313+
CodebookWeightOnlyConfig,
314+
)
315+
316+
minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
317+
if minimum_deployment_target and minimum_deployment_target < ct.target.iOS18:
318+
raise ValueError(
319+
"minimum_deployment_target must be iOS18 or higher for codebook quantization"
320+
)
321+
# Default to iOS18 for codebook quantization
322+
kwargs["minimum_deployment_target"] = ct.target.iOS18
323+
324+
# Get the appropriate dtype (torch.uint1 through torch.uint8)
325+
dtype = getattr(torch, f"uint{bits}")
326+
327+
# Use user-provided filter_fn or default to Linear/Embedding layers
328+
filter_fn = kwargs.get(
329+
"filter_fn",
330+
lambda m, _: (
331+
isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear)
332+
),
333+
)
334+
335+
config = AOQuantizationConfig(
336+
ao_base_config=CodebookWeightOnlyConfig(
337+
dtype=dtype,
338+
block_size=block_size,
339+
),
340+
filter_fn=filter_fn,
341+
)
342+
343+
quantization_recipe = QuantizationRecipe(
344+
quantizers=None,
345+
ao_quantization_configs=[config],
346+
)
347+
348+
lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
349+
350+
return ExportRecipe(
351+
name=recipe_type.value,
352+
quantization_recipe=quantization_recipe,
102353
lowering_recipe=lowering_recipe,
103354
)
104355

105356
def _get_coreml_lowering_recipe(
106357
self,
107-
compute_precision: ct.precision,
358+
compute_precision: ct.precision = ct.precision.FLOAT16,
108359
**kwargs: Any,
109360
) -> LoweringRecipe:
361+
"""Get CoreML lowering recipe with optional precision"""
110362
compile_specs = CoreMLBackend.generate_compile_specs(
111363
compute_precision=compute_precision,
112-
**kwargs,
364+
compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL),
365+
minimum_deployment_target=kwargs.get("minimum_deployment_target", None),
113366
)
114367

115368
minimum_deployment_target = kwargs.get("minimum_deployment_target", None)

0 commit comments

Comments
 (0)