
Commit 7658344

Add coreml quant recipes
ghstack-source-id: c72ea76
ghstack-comment-id: 3172341606
Pull-Request: #13265
1 parent 164fad0 commit 7658344

File tree

4 files changed, +812 -173 lines changed


backends/apple/coreml/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -120,6 +120,7 @@ runtime.python_test(
         "test/*.py",
     ]),
     deps = [
+        "fbsource//third-party/pypi/coremltools:coremltools",
         "fbsource//third-party/pypi/pytest:pytest",
         ":partitioner",
         ":quantizer",
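
The new coremltools dependency on the test target is presumably needed because the recipe changes below validate kwargs against coremltools enum types (ct.target, ct.ComputeUnit), so the tests import coremltools directly. A minimal, hypothetical pytest-style sketch of what the dependency enables; the test name and assertions are illustrative and not part of this commit:

# Hypothetical test sketch: ct.target and ct.ComputeUnit are the enum types the
# recipe provider below checks recipe kwargs against.
import coremltools as ct


def test_coremltools_enums_importable():
    assert isinstance(ct.target.iOS17, ct.target)
    assert isinstance(ct.ComputeUnit.CPU_AND_NE, ct.ComputeUnit)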

backends/apple/coreml/recipes/coreml_recipe_provider.py

Lines changed: 271 additions & 17 deletions
@@ -6,6 +6,7 @@
 from typing import Any, Optional, Sequence

 import coremltools as ct
+import torch

 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition.coreml_partitioner import (
@@ -18,11 +19,15 @@

 from executorch.exir import EdgeCompileConfig
 from executorch.export import (
+    AOQuantizationConfig,
     BackendRecipeProvider,
     ExportRecipe,
     LoweringRecipe,
+    QuantizationRecipe,
     RecipeType,
 )
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig


 class CoreMLRecipeProvider(BackendRecipeProvider):
@@ -50,66 +55,315 @@ def create_recipe(
         # Validate kwargs
         self._validate_recipe_kwargs(recipe_type, **kwargs)

-        # Parse recipe type to get precision and compute unit
-        precision = None
         if recipe_type == CoreMLRecipeType.FP32:
-            precision = ct.precision.FLOAT32
+            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs)
         elif recipe_type == CoreMLRecipeType.FP16:
-            precision = ct.precision.FLOAT16
-
-        if precision is None:
-            raise ValueError(f"Unknown precision for recipe: {recipe_type.value}")
+            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs)
+        elif recipe_type == CoreMLRecipeType.PT2E_INT8_STATIC:
+            return self._build_pt2e_quantized_recipe(
+                recipe_type, activation_dtype=torch.quint8, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY:
+            return self._build_pt2e_quantized_recipe(
+                recipe_type, activation_dtype=torch.float32, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL:
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int4,
+                is_per_channel=True,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP:
+            group_size = kwargs.pop("group_size", 32)
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int4,
+                is_per_channel=False,
+                group_size=group_size,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL:
+            return self._build_torchao_quantized_recipe(
+                recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
+            )
+        elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP:
+            group_size = kwargs.pop("group_size", 32)
+            return self._build_torchao_quantized_recipe(
+                recipe_type,
+                weight_dtype=torch.int8,
+                is_per_channel=False,
+                group_size=group_size,
+                **kwargs,
+            )
+        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            bits = kwargs.pop("bits", 3)
+            block_size = kwargs.pop("block_size", [-1, 16])
+            return self._build_codebook_quantized_recipe(
+                recipe_type, bits=bits, block_size=block_size, **kwargs
+            )

-        return self._build_recipe(recipe_type, precision, **kwargs)
+        return None

     def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
-        if not kwargs:
-            return
-        expected_keys = {"minimum_deployment_target", "compute_unit"}
+        """Validate kwargs for each recipe type"""
+        expected_keys = self._get_expected_keys(recipe_type)
+
         unexpected = set(kwargs.keys()) - expected_keys
         if unexpected:
             raise ValueError(
-                f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. "
-                f"Unexpected parameters: {list(unexpected)}"
+                f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
             )
+
+        self._validate_base_parameters(kwargs)
+        self._validate_group_size_parameter(recipe_type, kwargs)
+        self._validate_codebook_parameters(recipe_type, kwargs)
+
+    def _get_expected_keys(self, recipe_type: RecipeType) -> set:
+        """Get expected parameter keys for a recipe type"""
+        common_keys = {"minimum_deployment_target", "compute_unit"}
+
+        if recipe_type in [
+            CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
+            CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
+        ]:
+            return common_keys | {"group_size", "filter_fn"}
+        elif recipe_type in [
+            CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL,
+        ]:
+            return common_keys | {"filter_fn"}
+        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            return common_keys | {"bits", "block_size", "filter_fn"}
+        else:
+            return common_keys
+
+    def _validate_base_parameters(self, kwargs: Any) -> None:
+        """Validate minimum_deployment_target and compute_unit parameters"""
         if "minimum_deployment_target" in kwargs:
             minimum_deployment_target = kwargs["minimum_deployment_target"]
             if not isinstance(minimum_deployment_target, ct.target):
                 raise ValueError(
                     f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}"
                 )
+
         if "compute_unit" in kwargs:
             compute_unit = kwargs["compute_unit"]
             if not isinstance(compute_unit, ct.ComputeUnit):
                 raise ValueError(
                     f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}"
                 )

-    def _build_recipe(
+    def _validate_group_size_parameter(
+        self, recipe_type: RecipeType, kwargs: Any
+    ) -> None:
+        """Validate group_size parameter for applicable recipe types"""
+        if (
+            recipe_type
+            in [
+                CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
+                CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
+            ]
+            and "group_size" in kwargs
+        ):
+            group_size = kwargs["group_size"]
+            if not isinstance(group_size, int):
+                raise ValueError(
+                    f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
+                )
+            if group_size <= 0:
+                raise ValueError(
+                    f"Parameter 'group_size' must be positive, got: {group_size}"
+                )
+
+    def _validate_codebook_parameters(
+        self, recipe_type: RecipeType, kwargs: Any
+    ) -> None:
+        """Validate bits and block_size parameters for codebook recipe type"""
+        if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
+            return
+
+        if "bits" in kwargs:
+            bits = kwargs["bits"]
+            if not isinstance(bits, int):
+                raise ValueError(
+                    f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}"
+                )
+            if not (1 <= bits <= 8):
+                raise ValueError(
+                    f"Parameter 'bits' must be between 1 and 8, got: {bits}"
+                )
+
+        if "block_size" in kwargs:
+            block_size = kwargs["block_size"]
+            if not isinstance(block_size, list):
+                raise ValueError(
+                    f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}"
+                )
+
+    def _validate_and_set_deployment_target(
+        self, kwargs: Any, min_target: ct.target, quantization_type: str
+    ) -> None:
+        """Validate or set minimum deployment target for quantization recipes"""
+        minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
+        if minimum_deployment_target and minimum_deployment_target < min_target:
+            raise ValueError(
+                f"minimum_deployment_target must be {str(min_target)} or higher for {quantization_type} quantization"
+            )
+        else:
+            # Default to the minimum target for this quantization type
+            kwargs["minimum_deployment_target"] = min_target
+
+    def _build_fp_recipe(
         self,
         recipe_type: RecipeType,
         precision: ct.precision,
         **kwargs: Any,
     ) -> ExportRecipe:
+        """Build FP32/FP16 recipe"""
         lowering_recipe = self._get_coreml_lowering_recipe(
             compute_precision=precision,
             **kwargs,
         )

         return ExportRecipe(
             name=recipe_type.value,
-            quantization_recipe=None,  # TODO - add quantization recipe
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_pt2e_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        activation_dtype: torch.dtype,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build PT2E-based quantization recipe"""
+        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
+
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS17, "pt2e")
+
+        # Validate activation_dtype
+        assert activation_dtype in [
+            torch.quint8,
+            torch.float32,
+        ], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}"
+
+        # Create quantization config
+        config = ct.optimize.torch.quantization.LinearQuantizerConfig(
+            global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
+                quantization_scheme="symmetric",
+                activation_dtype=activation_dtype,
+                weight_dtype=torch.qint8,
+                weight_per_channel=True,
+            )
+        )
+
+        quantizer = CoreMLQuantizer(config)
+        quantization_recipe = QuantizationRecipe(quantizers=[quantizer])
+
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_torchao_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        weight_dtype: torch.dtype,
+        is_per_channel: bool,
+        group_size: int = 32,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build TorchAO-based quantization recipe"""
+        if is_per_channel:
+            weight_granularity = PerAxis(axis=0)
+        else:
+            weight_granularity = PerGroup(group_size=group_size)
+
+        # Use user-provided filter_fn if provided
+        filter_fn = kwargs.get("filter_fn", None)
+        config = AOQuantizationConfig(
+            ao_base_config=IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=weight_granularity,
+            ),
+            filter_fn=filter_fn,
+        )
+
+        quantization_recipe = QuantizationRecipe(
+            quantizers=None,
+            ao_quantization_configs=[config],
+        )
+
+        # override minimum_deployment_target to ios18 for torchao (GH issue #13122)
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
+            lowering_recipe=lowering_recipe,
+        )
+
+    def _build_codebook_quantized_recipe(
+        self,
+        recipe_type: RecipeType,
+        bits: int,
+        block_size: list,
+        **kwargs: Any,
+    ) -> ExportRecipe:
+        """Build codebook/palettization quantization recipe"""
+        from torchao.prototype.quantization.codebook_coreml import (
+            CodebookWeightOnlyConfig,
+        )
+
+        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "codebook")
+
+        # Get the appropriate dtype (torch.uint1 through torch.uint8)
+        dtype = getattr(torch, f"uint{bits}")
+
+        # Use user-provided filter_fn or default to Linear/Embedding layers
+        filter_fn = kwargs.get(
+            "filter_fn",
+            lambda m, fqn: (
+                isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear)
+            ),
+        )
+
+        config = AOQuantizationConfig(
+            ao_base_config=CodebookWeightOnlyConfig(
+                dtype=dtype,
+                block_size=block_size,
+            ),
+            filter_fn=filter_fn,
+        )
+
+        quantization_recipe = QuantizationRecipe(
+            quantizers=None,
+            ao_quantization_configs=[config],
+        )
+
+        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
+
+        return ExportRecipe(
+            name=recipe_type.value,
+            quantization_recipe=quantization_recipe,
             lowering_recipe=lowering_recipe,
         )

     def _get_coreml_lowering_recipe(
         self,
-        compute_precision: ct.precision,
+        compute_precision: ct.precision = ct.precision.FLOAT16,
         **kwargs: Any,
     ) -> LoweringRecipe:
+        """Get CoreML lowering recipe with optional precision"""
         compile_specs = CoreMLBackend.generate_compile_specs(
             compute_precision=compute_precision,
-            **kwargs,
+            compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL),
+            minimum_deployment_target=kwargs.get("minimum_deployment_target", None),
         )

         minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
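
For context, a minimal usage sketch of the recipes added above. The recipe type names and the accepted kwargs (group_size, bits, block_size, filter_fn, minimum_deployment_target, compute_unit) come from this diff; the CoreMLRecipeType import path, calling create_recipe directly on the provider rather than through the export recipe registry, and the specific parameter values are assumptions for illustration only:

# Usage sketch (not part of this commit): exercising the new CoreML recipe types.
import coremltools as ct
import torch

from executorch.backends.apple.coreml.recipes.coreml_recipe_provider import (
    CoreMLRecipeProvider,
)
# Assumed import path for the recipe-type enum; it is referenced but not defined in this diff.
from executorch.backends.apple.coreml.recipes.coreml_recipe_types import CoreMLRecipeType

provider = CoreMLRecipeProvider()

# 4-bit weight-only, per-group quantization; group_size defaults to 32 when omitted,
# and minimum_deployment_target is raised to iOS18 for TorchAO-based recipes.
int4_recipe = provider.create_recipe(
    CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
    group_size=32,
    compute_unit=ct.ComputeUnit.CPU_AND_NE,
)

# Codebook (palettization) quantization; bits defaults to 3 and block_size to [-1, 16].
# The default filter targets Linear/Embedding modules; a custom one can be supplied.
codebook_recipe = provider.create_recipe(
    CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY,
    bits=4,
    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear),
)

# PT2E int8 static quantization; iOS17 is set as the deployment-target floor automatically.
pt2e_recipe = provider.create_recipe(CoreMLRecipeType.PT2E_INT8_STATIC)

Each call returns an ExportRecipe that bundles a QuantizationRecipe (PT2E quantizers or TorchAO configs) with the CoreML lowering recipe built by _get_coreml_lowering_recipe.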
