Commit 14b0088

Revert "Add TorchAO wrapper config to allow filter_fn for quantize_ (#13264)" and "Add coreml quant recipes (#13265)" (#13374)
This reverts commits 0a7cea8 and 310a05d. It appears that #13264 broke the unittest jobs, and #13265 depends on it.
Parent: ea4a7fa

File tree: 12 files changed, +258 −1033 lines

12 files changed

+258
-1033
lines changed

backends/apple/coreml/TARGETS

Lines changed: 0 additions & 1 deletion
@@ -120,7 +120,6 @@ runtime.python_test(
         "test/*.py",
     ]),
     deps = [
-        "fbsource//third-party/pypi/coremltools:coremltools",
         "fbsource//third-party/pypi/pytest:pytest",
         ":partitioner",
         ":quantizer",

backends/apple/coreml/recipes/coreml_recipe_provider.py

Lines changed: 17 additions & 277 deletions
@@ -6,7 +6,6 @@
 from typing import Any, Optional, Sequence

 import coremltools as ct
-import torch

 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition.coreml_partitioner import (
@@ -19,15 +18,11 @@

 from executorch.exir import EdgeCompileConfig
 from executorch.export import (
-    AOQuantizationConfig,
     BackendRecipeProvider,
     ExportRecipe,
     LoweringRecipe,
-    QuantizationRecipe,
     RecipeType,
 )
-from torchao.quantization.granularity import PerAxis, PerGroup
-from torchao.quantization.quant_api import IntxWeightOnlyConfig


 class CoreMLRecipeProvider(BackendRecipeProvider):
@@ -55,321 +50,66 @@ def create_recipe(
         # Validate kwargs
         self._validate_recipe_kwargs(recipe_type, **kwargs)

+        # Parse recipe type to get precision and compute unit
+        precision = None
         if recipe_type == CoreMLRecipeType.FP32:
-            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs)
+            precision = ct.precision.FLOAT32
         elif recipe_type == CoreMLRecipeType.FP16:
-            return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs)
-        elif recipe_type == CoreMLRecipeType.PT2E_INT8_STATIC:
-            return self._build_pt2e_quantized_recipe(
-                recipe_type, activation_dtype=torch.quint8, **kwargs
-            )
-        elif recipe_type == CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY:
-            return self._build_pt2e_quantized_recipe(
-                recipe_type, activation_dtype=torch.float32, **kwargs
-            )
-        elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL:
-            return self._build_torchao_quantized_recipe(
-                recipe_type,
-                weight_dtype=torch.int4,
-                is_per_channel=True,
-                **kwargs,
-            )
-        elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP:
-            group_size = kwargs.pop("group_size", 32)
-            return self._build_torchao_quantized_recipe(
-                recipe_type,
-                weight_dtype=torch.int4,
-                is_per_channel=False,
-                group_size=group_size,
-                **kwargs,
-            )
-        elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL:
-            return self._build_torchao_quantized_recipe(
-                recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
-            )
-        elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP:
-            group_size = kwargs.pop("group_size", 32)
-            return self._build_torchao_quantized_recipe(
-                recipe_type,
-                weight_dtype=torch.int8,
-                is_per_channel=False,
-                group_size=group_size,
-                **kwargs,
-            )
-        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
-            bits = kwargs.pop("bits")
-            block_size = kwargs.pop("block_size")
-            return self._build_codebook_quantized_recipe(
-                recipe_type, bits=bits, block_size=block_size, **kwargs
-            )
+            precision = ct.precision.FLOAT16

-        return None
+        if precision is None:
+            raise ValueError(f"Unknown precision for recipe: {recipe_type.value}")

-    def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
-        """Validate kwargs for each recipe type"""
-        expected_keys = self._get_expected_keys(recipe_type)
+        return self._build_recipe(recipe_type, precision, **kwargs)

+    def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
+        if not kwargs:
+            return
+        expected_keys = {"minimum_deployment_target", "compute_unit"}
         unexpected = set(kwargs.keys()) - expected_keys
         if unexpected:
             raise ValueError(
-                f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
+                f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. "
+                f"Unexpected parameters: {list(unexpected)}"
             )
-
-        self._validate_base_parameters(kwargs)
-        self._validate_group_size_parameter(recipe_type, kwargs)
-        self._validate_codebook_parameters(recipe_type, kwargs)
-
-    def _get_expected_keys(self, recipe_type: RecipeType) -> set:
-        """Get expected parameter keys for a recipe type"""
-        common_keys = {"minimum_deployment_target", "compute_unit"}
-
-        if recipe_type in [
-            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
-            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
-        ]:
-            return common_keys | {"group_size", "filter_fn"}
-        elif recipe_type in [
-            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL,
-            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
-        ]:
-            return common_keys | {"filter_fn"}
-        elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
-            return common_keys | {"bits", "block_size", "filter_fn"}
-        else:
-            return common_keys
-
-    def _validate_base_parameters(self, kwargs: Any) -> None:
-        """Validate minimum_deployment_target and compute_unit parameters"""
         if "minimum_deployment_target" in kwargs:
             minimum_deployment_target = kwargs["minimum_deployment_target"]
             if not isinstance(minimum_deployment_target, ct.target):
                 raise ValueError(
                     f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}"
                 )
-
         if "compute_unit" in kwargs:
             compute_unit = kwargs["compute_unit"]
             if not isinstance(compute_unit, ct.ComputeUnit):
                 raise ValueError(
                     f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}"
                 )

-    def _validate_group_size_parameter(
-        self, recipe_type: RecipeType, kwargs: Any
-    ) -> None:
-        """Validate group_size parameter for applicable recipe types"""
-        if (
-            recipe_type
-            in [
-                CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
-                CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
-            ]
-            and "group_size" in kwargs
-        ):
-            group_size = kwargs["group_size"]
-            if not isinstance(group_size, int):
-                raise ValueError(
-                    f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
-                )
-            if group_size <= 0:
-                raise ValueError(
-                    f"Parameter 'group_size' must be positive, got: {group_size}"
-                )
-
-    def _validate_codebook_parameters(
-        self, recipe_type: RecipeType, kwargs: Any
-    ) -> None:
-        """Validate bits and block_size parameters for codebook recipe type"""
-        if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
-            return
-
-        # Both bits and block_size must be present
-        if not ("bits" in kwargs and "block_size" in kwargs):
-            raise ValueError(
-                "Parameters 'bits' and 'block_size' must be present for codebook recipes"
-            )
-
-        if "bits" in kwargs:
-            bits = kwargs["bits"]
-            if not isinstance(bits, int):
-                raise ValueError(
-                    f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}"
-                )
-            if not (1 <= bits <= 8):
-                raise ValueError(
-                    f"Parameter 'bits' must be between 1 and 8, got: {bits}"
-                )
-
-        if "block_size" in kwargs:
-            block_size = kwargs["block_size"]
-            if not isinstance(block_size, list):
-                raise ValueError(
-                    f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}"
-                )
-
-    def _validate_and_set_deployment_target(
-        self, kwargs: Any, min_target: ct.target, quantization_type: str
-    ) -> None:
-        """Validate or set minimum deployment target for quantization recipes"""
-        minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
-        if minimum_deployment_target and minimum_deployment_target < min_target:
-            raise ValueError(
-                f"minimum_deployment_target must be {str(min_target)} or higher for {quantization_type} quantization"
-            )
-        else:
-            # Default to the minimum target for this quantization type
-            kwargs["minimum_deployment_target"] = min_target
-
-    def _build_fp_recipe(
+    def _build_recipe(
         self,
         recipe_type: RecipeType,
         precision: ct.precision,
         **kwargs: Any,
     ) -> ExportRecipe:
-        """Build FP32/FP16 recipe"""
         lowering_recipe = self._get_coreml_lowering_recipe(
             compute_precision=precision,
             **kwargs,
         )

         return ExportRecipe(
             name=recipe_type.value,
-            lowering_recipe=lowering_recipe,
-        )
-
-    def _build_pt2e_quantized_recipe(
-        self,
-        recipe_type: RecipeType,
-        activation_dtype: torch.dtype,
-        **kwargs: Any,
-    ) -> ExportRecipe:
-        """Build PT2E-based quantization recipe"""
-        from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
-
-        self._validate_and_set_deployment_target(kwargs, ct.target.iOS17, "pt2e")
-
-        # Validate activation_dtype
-        assert activation_dtype in [
-            torch.quint8,
-            torch.float32,
-        ], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}"
-
-        # Create quantization config
-        config = ct.optimize.torch.quantization.LinearQuantizerConfig(
-            global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
-                quantization_scheme="symmetric",
-                activation_dtype=activation_dtype,
-                weight_dtype=torch.qint8,
-                weight_per_channel=True,
-            )
-        )
-
-        quantizer = CoreMLQuantizer(config)
-        quantization_recipe = QuantizationRecipe(quantizers=[quantizer])
-
-        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
-
-        return ExportRecipe(
-            name=recipe_type.value,
-            quantization_recipe=quantization_recipe,
-            lowering_recipe=lowering_recipe,
-        )
-
-    def _build_torchao_quantized_recipe(
-        self,
-        recipe_type: RecipeType,
-        weight_dtype: torch.dtype,
-        is_per_channel: bool,
-        group_size: int = 32,
-        **kwargs: Any,
-    ) -> ExportRecipe:
-        """Build TorchAO-based quantization recipe"""
-        if is_per_channel:
-            weight_granularity = PerAxis(axis=0)
-        else:
-            weight_granularity = PerGroup(group_size=group_size)
-
-        # Use user-provided filter_fn if provided
-        filter_fn = kwargs.get("filter_fn", None)
-        config = AOQuantizationConfig(
-            ao_base_config=IntxWeightOnlyConfig(
-                weight_dtype=weight_dtype,
-                granularity=weight_granularity,
-            ),
-            filter_fn=filter_fn,
-        )
-
-        quantization_recipe = QuantizationRecipe(
-            quantizers=None,
-            ao_quantization_configs=[config],
-        )
-
-        # override minimum_deployment_target to ios18 for torchao (GH issue #13122)
-        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
-        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
-
-        return ExportRecipe(
-            name=recipe_type.value,
-            quantization_recipe=quantization_recipe,
-            lowering_recipe=lowering_recipe,
-        )
-
-    def _build_codebook_quantized_recipe(
-        self,
-        recipe_type: RecipeType,
-        bits: int,
-        block_size: list,
-        **kwargs: Any,
-    ) -> ExportRecipe:
-        """Build codebook/palettization quantization recipe"""
-        from torchao.prototype.quantization.codebook_coreml import (
-            CodebookWeightOnlyConfig,
-        )
-
-        self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "codebook")
-
-        # Get the appropriate dtype (torch.uint1 through torch.uint8)
-        dtype = getattr(torch, f"uint{bits}")
-
-        # Use user-provided filter_fn or default to Linear/Embedding layers
-        filter_fn = kwargs.get(
-            "filter_fn",
-            lambda m, fqn: (
-                isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear)
-            ),
-        )
-
-        config = AOQuantizationConfig(
-            ao_base_config=CodebookWeightOnlyConfig(
-                dtype=dtype,
-                block_size=block_size,
-            ),
-            filter_fn=filter_fn,
-        )
-
-        quantization_recipe = QuantizationRecipe(
-            quantizers=None,
-            ao_quantization_configs=[config],
-        )
-
-        lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)
-
-        return ExportRecipe(
-            name=recipe_type.value,
-            quantization_recipe=quantization_recipe,
+            quantization_recipe=None,  # TODO - add quantization recipe
             lowering_recipe=lowering_recipe,
         )

     def _get_coreml_lowering_recipe(
         self,
-        compute_precision: ct.precision = ct.precision.FLOAT16,
+        compute_precision: ct.precision,
         **kwargs: Any,
     ) -> LoweringRecipe:
-        """Get CoreML lowering recipe with optional precision"""
         compile_specs = CoreMLBackend.generate_compile_specs(
             compute_precision=compute_precision,
-            compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL),
-            minimum_deployment_target=kwargs.get("minimum_deployment_target", None),
+            **kwargs,
         )

         minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
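
For context: after this revert the provider only builds the FP32 and FP16 recipes, and _validate_recipe_kwargs rejects everything except 'minimum_deployment_target' and 'compute_unit'. Below is a minimal usage sketch of that post-revert surface; the CoreMLRecipeType import path is inferred from this commit's file layout, and the create_recipe call shape from the hunk above, so treat both as assumptions rather than documented API.

import coremltools as ct

# Import paths inferred from the file layout in this commit; the module
# holding CoreMLRecipeType in particular is an assumption.
from executorch.backends.apple.coreml.recipes.coreml_recipe_provider import (
    CoreMLRecipeProvider,
)
from executorch.backends.apple.coreml.recipes.coreml_recipe_types import (
    CoreMLRecipeType,
)

provider = CoreMLRecipeProvider()

# Post-revert, only the two float recipes exist, and the only accepted
# kwargs are 'minimum_deployment_target' and 'compute_unit'.
recipe = provider.create_recipe(
    CoreMLRecipeType.FP16,
    minimum_deployment_target=ct.target.iOS17,  # must be a ct.target enum
    compute_unit=ct.ComputeUnit.ALL,  # must be a ct.ComputeUnit enum
)

# Kwargs from the reverted quantization recipes now fail validation, e.g.:
#   provider.create_recipe(CoreMLRecipeType.FP16, group_size=32)
# raises ValueError: "CoreML Recipes only accept 'minimum_deployment_target'
# or 'compute_unit' as parameter. ..."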
