Skip to content

Commit bf7413d

Browse files
Add coreml quant recipes
ghstack-source-id: 1bb73e0 ghstack-comment-id: 3172341606 Pull-Request: #13265
1 parent cb80c69 commit bf7413d

File tree

5 files changed

+750
-183
lines changed

5 files changed

+750
-183
lines changed

backends/apple/coreml/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ runtime.python_test(
120120
"test/*.py",
121121
]),
122122
deps = [
123+
"fbsource//third-party/pypi/coremltools:coremltools",
123124
"fbsource//third-party/pypi/pytest:pytest",
124125
":partitioner",
125126
":quantizer",

backends/apple/coreml/recipes/coreml_recipe_provider.py

Lines changed: 271 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Optional, Sequence
77

88
import coremltools as ct
9+
import torch
910

1011
from executorch.backends.apple.coreml.compiler import CoreMLBackend
1112
from executorch.backends.apple.coreml.partition.coreml_partitioner import (
@@ -18,11 +19,16 @@
1819

1920
from executorch.exir import EdgeCompileConfig
2021
from executorch.export import (
22+
AOQuantizationConfig,
2123
BackendRecipeProvider,
2224
ExportRecipe,
2325
LoweringRecipe,
26+
QuantizationRecipe,
2427
RecipeType,
2528
)
29+
from torchao.quantization.granularity import PerAxis, PerGroup
30+
from torchao.quantization.quant_api import IntxWeightOnlyConfig
31+
from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH
2632

2733

2834
class CoreMLRecipeProvider(BackendRecipeProvider):
@@ -50,66 +56,314 @@ def create_recipe(
5056
# Validate kwargs
5157
self._validate_recipe_kwargs(recipe_type, **kwargs)
5258

53-
# Parse recipe type to get precision and compute unit
54-
precision = None
5559
if recipe_type == CoreMLRecipeType.FP32:
56-
precision = ct.precision.FLOAT32
60+
return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs)
5761
elif recipe_type == CoreMLRecipeType.FP16:
58-
precision = ct.precision.FLOAT16
59-
60-
if precision is None:
61-
raise ValueError(f"Unknown precision for recipe: {recipe_type.value}")
62+
return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs)
63+
elif recipe_type == CoreMLRecipeType.INT8_STATIC:
64+
return self._build_pt2e_quantized_recipe(
65+
recipe_type, activation_dtype=torch.quint8, **kwargs
66+
)
67+
elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY:
68+
return self._build_pt2e_quantized_recipe(
69+
recipe_type, activation_dtype=torch.float32, **kwargs
70+
)
71+
elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL:
72+
return self._build_torchao_quantized_recipe(
73+
recipe_type,
74+
weight_dtype=torch.int4,
75+
is_per_channel=True,
76+
**kwargs,
77+
)
78+
elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP:
79+
group_size = kwargs.pop("group_size", 32)
80+
return self._build_torchao_quantized_recipe(
81+
recipe_type,
82+
weight_dtype=torch.int4,
83+
is_per_channel=False,
84+
group_size=group_size,
85+
**kwargs,
86+
)
87+
elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL:
88+
return self._build_torchao_quantized_recipe(
89+
recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
90+
)
91+
elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP:
92+
group_size = kwargs.pop("group_size", 32)
93+
return self._build_torchao_quantized_recipe(
94+
recipe_type,
95+
weight_dtype=torch.int8,
96+
is_per_channel=False,
97+
group_size=group_size,
98+
**kwargs,
99+
)
100+
elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
101+
bits = kwargs.pop("bits", 3)
102+
block_size = kwargs.pop("block_size", [-1, 16])
103+
return self._build_codebook_quantized_recipe(
104+
recipe_type, bits=bits, block_size=block_size, **kwargs
105+
)
62106

63-
return self._build_recipe(recipe_type, precision, **kwargs)
107+
return None
64108

65109
def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
    """Check the keyword arguments supplied for ``recipe_type``.

    Any key outside the set returned by ``_get_expected_keys`` is rejected
    first; the per-parameter validators then inspect the values.

    Raises:
        ValueError: If an unexpected key is present, or if one of the
            delegated validators rejects a value.
    """
    allowed = self._get_expected_keys(recipe_type)
    extras = set(kwargs.keys()) - allowed
    if extras:
        raise ValueError(
            f"Recipe '{recipe_type.value}' received unexpected parameters: {list(extras)}"
        )

    # Value checks only run once the key set is known to be acceptable.
    self._validate_base_parameters(kwargs)
    self._validate_group_size_parameter(recipe_type, kwargs)
    self._validate_codebook_parameters(recipe_type, kwargs)
123+
def _get_expected_keys(self, recipe_type: RecipeType) -> set:
    """Return the keyword names that ``recipe_type`` accepts.

    Every recipe accepts ``minimum_deployment_target`` and ``compute_unit``;
    the weight-only torchao and codebook recipes accept a few more.
    """
    per_group_recipes = (
        CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
        CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
    )
    per_channel_recipes = (
        CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL,
        CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL,
    )

    if recipe_type in per_group_recipes:
        extra_keys = {"group_size", "filter_fn"}
    elif recipe_type in per_channel_recipes:
        extra_keys = {"filter_fn"}
    elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
        extra_keys = {"bits", "block_size", "filter_fn"}
    else:
        # All remaining recipe types take only the shared keys.
        extra_keys = set()

    return {"minimum_deployment_target", "compute_unit"} | extra_keys
141+
142+
def _validate_base_parameters(self, kwargs: Any) -> None:
    """Type-check the two parameters shared by every recipe.

    Args:
        kwargs: Mapping of recipe keyword arguments. Only the keys
            ``minimum_deployment_target`` and ``compute_unit`` are inspected,
            and only when they are present.

    Raises:
        ValueError: If ``minimum_deployment_target`` is not a ``ct.target``
            or ``compute_unit`` is not a ``ct.ComputeUnit``.
    """
    has_target = "minimum_deployment_target" in kwargs
    if has_target and not isinstance(kwargs["minimum_deployment_target"], ct.target):
        target = kwargs["minimum_deployment_target"]
        raise ValueError(
            f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(target)}"
        )

    has_unit = "compute_unit" in kwargs
    if has_unit and not isinstance(kwargs["compute_unit"], ct.ComputeUnit):
        unit = kwargs["compute_unit"]
        raise ValueError(
            f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(unit)}"
        )
87157

88-
def _build_recipe(
158+
def _validate_group_size_parameter(
    self, recipe_type: RecipeType, kwargs: Any
) -> None:
    """Validate ``group_size`` for the per-group weight-only recipes.

    The check is skipped for every other recipe type and when the caller did
    not pass ``group_size``.

    Args:
        recipe_type: Recipe being built.
        kwargs: Mapping of recipe keyword arguments.

    Raises:
        ValueError: If ``group_size`` is not an ``int`` (``bool`` is rejected
            too) or is not strictly positive.
    """
    per_group_recipes = (
        CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
        CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
    )
    if recipe_type not in per_group_recipes or "group_size" not in kwargs:
        return

    group_size = kwargs["group_size"]
    # bool is a subclass of int, so a plain isinstance check would let
    # group_size=True through as a group size of 1.
    if isinstance(group_size, bool) or not isinstance(group_size, int):
        raise ValueError(
            f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
        )
    if group_size <= 0:
        raise ValueError(
            f"Parameter 'group_size' must be positive, got: {group_size}"
        )
179+
180+
def _validate_codebook_parameters(
    self, recipe_type: RecipeType, kwargs: Any
) -> None:
    """Validate ``bits`` and ``block_size`` for the codebook recipe.

    Does nothing for any recipe type other than ``CODEBOOK_WEIGHT_ONLY``.
    Each parameter is checked only when present in ``kwargs``.

    Args:
        recipe_type: Recipe being built.
        kwargs: Mapping of recipe keyword arguments.

    Raises:
        ValueError: If ``bits`` is not an ``int`` in ``[1, 8]`` (``bool`` is
            rejected), or if ``block_size`` is not a list of integers.
    """
    if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
        return

    if "bits" in kwargs:
        bits = kwargs["bits"]
        # bool is a subclass of int; without the explicit exclusion
        # bits=True would be accepted as a 1-bit codebook.
        if isinstance(bits, bool) or not isinstance(bits, int):
            raise ValueError(
                f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}"
            )
        if not (1 <= bits <= 8):
            raise ValueError(
                f"Parameter 'bits' must be between 1 and 8, got: {bits}"
            )

    if "block_size" in kwargs:
        block_size = kwargs["block_size"]
        if not isinstance(block_size, list):
            raise ValueError(
                f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}"
            )
        # Reject malformed entries here rather than letting them surface
        # later as an opaque failure inside the quantization config.
        if any(
            isinstance(dim, bool) or not isinstance(dim, int) for dim in block_size
        ):
            raise ValueError(
                f"Parameter 'block_size' must contain only integers, got: {block_size}"
            )
204+
205+
def _build_fp_recipe(
    self,
    recipe_type: RecipeType,
    precision: ct.precision,
    **kwargs: Any,
) -> ExportRecipe:
    """Assemble an unquantized recipe that lowers at the given precision.

    Args:
        recipe_type: Recipe being built; its ``value`` becomes the recipe name.
        precision: Compute precision forwarded to the lowering recipe.
        **kwargs: Remaining recipe options, forwarded unchanged to
            ``_get_coreml_lowering_recipe``.

    Returns:
        An ``ExportRecipe`` carrying only a lowering recipe.
    """
    coreml_lowering = self._get_coreml_lowering_recipe(
        compute_precision=precision, **kwargs
    )
    return ExportRecipe(name=recipe_type.value, lowering_recipe=coreml_lowering)
221+
222+
def _build_pt2e_quantized_recipe(
    self,
    recipe_type: RecipeType,
    activation_dtype: torch.dtype,
    **kwargs: Any,
) -> ExportRecipe:
    """Build a PT2E-based quantization recipe.

    The recipe pairs a ``CoreMLQuantizer`` (symmetric scheme, per-channel
    ``torch.qint8`` weights, activations in ``activation_dtype``) with a
    CoreML lowering recipe.

    Args:
        recipe_type: Recipe being built; its ``value`` becomes the recipe name.
        activation_dtype: ``torch.quint8`` or ``torch.float32``.
        **kwargs: Remaining recipe options. ``minimum_deployment_target``
            defaults to ``ct.target.iOS17`` when omitted; a higher
            user-supplied target is kept.

    Returns:
        An ``ExportRecipe`` with both a quantization and a lowering recipe.

    Raises:
        ValueError: If ``minimum_deployment_target`` is below iOS17 or
            ``activation_dtype`` is not one of the two supported dtypes.
    """
    # Imported at call time, as in the original implementation.
    from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer

    minimum_deployment_target = kwargs.get("minimum_deployment_target")
    if minimum_deployment_target is None:
        # Only fill in the default; never downgrade an explicit higher target.
        kwargs["minimum_deployment_target"] = ct.target.iOS17
    elif minimum_deployment_target < ct.target.iOS17:
        raise ValueError(
            "minimum_deployment_target must be iOS17 or higher for PT2E quantization"
        )

    # Raise explicitly rather than assert: asserts are stripped under -O.
    if activation_dtype not in (torch.quint8, torch.float32):
        raise ValueError(
            f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}"
        )

    config = ct.optimize.torch.quantization.LinearQuantizerConfig(
        global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
            quantization_scheme="symmetric",
            activation_dtype=activation_dtype,
            weight_dtype=torch.qint8,
            weight_per_channel=True,
        )
    )

    quantizer = CoreMLQuantizer(config)
    quantization_recipe = QuantizationRecipe(quantizers=[quantizer])

    lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)

    return ExportRecipe(
        name=recipe_type.value,
        quantization_recipe=quantization_recipe,
        lowering_recipe=lowering_recipe,
    )
265+
266+
def _build_torchao_quantized_recipe(
    self,
    recipe_type: RecipeType,
    weight_dtype: torch.dtype,
    is_per_channel: bool,
    group_size: int = 32,
    **kwargs: Any,
) -> ExportRecipe:
    """Build a TorchAO weight-only quantization recipe.

    Args:
        recipe_type: Recipe being built; its ``value`` becomes the recipe name.
        weight_dtype: Dtype handed to ``IntxWeightOnlyConfig``.
        is_per_channel: Selects ``PerAxis(axis=0)`` when true, otherwise
            ``PerGroup(group_size=group_size)``.
        group_size: Group size for the per-group case; unused when
            ``is_per_channel`` is true.
        **kwargs: Remaining recipe options. ``filter_fn`` is read from here,
            and ``minimum_deployment_target`` is always replaced (see below).

    Returns:
        An ``ExportRecipe`` with a torchao quantization config and a CoreML
        lowering recipe.
    """
    granularity = (
        PerAxis(axis=0) if is_per_channel else PerGroup(group_size=group_size)
    )

    ao_config = AOQuantizationConfig(
        ao_base_config=IntxWeightOnlyConfig(
            weight_dtype=weight_dtype,
            granularity=granularity,
        ),
        # No local fallback: when the caller gives no filter_fn, None is
        # forwarded and module selection is left to AOQuantizationConfig.
        filter_fn=kwargs.get("filter_fn", None),
    )
    quantization_recipe = QuantizationRecipe(
        quantizers=None,
        ao_quantization_configs=[ao_config],
    )

    # override minimum_deployment_target to ios18 for torchao (GH issue #13122)
    kwargs["minimum_deployment_target"] = ct.target.iOS18

    return ExportRecipe(
        name=recipe_type.value,
        quantization_recipe=quantization_recipe,
        lowering_recipe=self._get_coreml_lowering_recipe(**kwargs),
    )
304+
305+
def _build_codebook_quantized_recipe(
    self,
    recipe_type: RecipeType,
    bits: int,
    block_size: list,
    **kwargs: Any,
) -> ExportRecipe:
    """Build a codebook (palettization) weight-only quantization recipe.

    Args:
        recipe_type: Recipe being built; its ``value`` becomes the recipe name.
        bits: Codebook bit width; mapped to the ``torch.uint{bits}`` dtype.
        block_size: Block size list handed to ``CodebookWeightOnlyConfig``.
        **kwargs: Remaining recipe options. ``filter_fn`` defaults to a
            predicate selecting ``torch.nn.Embedding`` and ``torch.nn.Linear``
            modules. ``minimum_deployment_target`` defaults to
            ``ct.target.iOS18`` when omitted; a higher user-supplied target
            is kept.

    Returns:
        An ``ExportRecipe`` with a torchao codebook config and a CoreML
        lowering recipe.

    Raises:
        ValueError: If ``minimum_deployment_target`` is below iOS18.
    """
    # Imported at call time, as in the original implementation.
    from torchao.prototype.quantization.codebook_coreml import (
        CodebookWeightOnlyConfig,
    )

    minimum_deployment_target = kwargs.get("minimum_deployment_target")
    if minimum_deployment_target is None:
        # Only fill in the default; never downgrade an explicit higher target.
        kwargs["minimum_deployment_target"] = ct.target.iOS18
    elif minimum_deployment_target < ct.target.iOS18:
        raise ValueError(
            "minimum_deployment_target must be iOS18 or higher for codebook quantization"
        )

    # Get the appropriate dtype (torch.uint1 through torch.uint8)
    dtype = getattr(torch, f"uint{bits}")

    # Use the user-provided filter_fn, or default to Linear/Embedding layers.
    filter_fn = kwargs.get(
        "filter_fn",
        lambda m, _: isinstance(m, (torch.nn.Embedding, torch.nn.Linear)),
    )

    config = AOQuantizationConfig(
        ao_base_config=CodebookWeightOnlyConfig(
            dtype=dtype,
            block_size=block_size,
        ),
        filter_fn=filter_fn,
    )

    quantization_recipe = QuantizationRecipe(
        quantizers=None,
        ao_quantization_configs=[config],
    )

    lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)

    return ExportRecipe(
        name=recipe_type.value,
        quantization_recipe=quantization_recipe,
        lowering_recipe=lowering_recipe,
    )
104356

105357
def _get_coreml_lowering_recipe(
106358
self,
107-
compute_precision: ct.precision,
359+
compute_precision: ct.precision = ct.precision.FLOAT16,
108360
**kwargs: Any,
109361
) -> LoweringRecipe:
362+
"""Get CoreML lowering recipe with optional precision"""
110363
compile_specs = CoreMLBackend.generate_compile_specs(
111364
compute_precision=compute_precision,
112-
**kwargs,
365+
compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL),
366+
minimum_deployment_target=kwargs.get("minimum_deployment_target", None),
113367
)
114368

115369
minimum_deployment_target = kwargs.get("minimum_deployment_target", None)

0 commit comments

Comments
 (0)