Add coreml quant recipes #13265

Open · wants to merge 7 commits into base: gh/abhinaykukkadapu/4/head
Changes from all commits
1 change: 1 addition & 0 deletions backends/apple/coreml/TARGETS
@@ -120,6 +120,7 @@ runtime.python_test(
"test/*.py",
]),
deps = [
"fbsource//third-party/pypi/coremltools:coremltools",
"fbsource//third-party/pypi/pytest:pytest",
":partitioner",
":quantizer",
288 changes: 271 additions & 17 deletions backends/apple/coreml/recipes/coreml_recipe_provider.py
@@ -6,6 +6,7 @@
from typing import Any, Optional, Sequence

import coremltools as ct
import torch

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition.coreml_partitioner import (
@@ -18,11 +19,15 @@

from executorch.exir import EdgeCompileConfig
from executorch.export import (
AOQuantizationConfig,
BackendRecipeProvider,
ExportRecipe,
LoweringRecipe,
QuantizationRecipe,
RecipeType,
)
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig


class CoreMLRecipeProvider(BackendRecipeProvider):
@@ -50,66 +55,315 @@ def create_recipe(
# Validate kwargs
self._validate_recipe_kwargs(recipe_type, **kwargs)

- # Parse recipe type to get precision and compute unit
- precision = None
if recipe_type == CoreMLRecipeType.FP32:
- precision = ct.precision.FLOAT32
+ return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs)
elif recipe_type == CoreMLRecipeType.FP16:
- precision = ct.precision.FLOAT16
- if precision is None:
- raise ValueError(f"Unknown precision for recipe: {recipe_type.value}")
+ return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs)
elif recipe_type == CoreMLRecipeType.PT2E_INT8_STATIC:
return self._build_pt2e_quantized_recipe(
recipe_type, activation_dtype=torch.quint8, **kwargs
)
elif recipe_type == CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY:
return self._build_pt2e_quantized_recipe(
recipe_type, activation_dtype=torch.float32, **kwargs
)
elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL:
return self._build_torchao_quantized_recipe(
recipe_type,
weight_dtype=torch.int4,
is_per_channel=True,
**kwargs,
)
elif recipe_type == CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP:
group_size = kwargs.pop("group_size", 32)
return self._build_torchao_quantized_recipe(
[Inline review comment]
Contributor: What is the compute precision used in these quantization recipes?
Contributor Author: @metascroy It is FLOAT16; see the _get_coreml_lowering_recipe function defaults.
recipe_type,
weight_dtype=torch.int4,
is_per_channel=False,
group_size=group_size,
**kwargs,
)
elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL:
return self._build_torchao_quantized_recipe(
recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs
)
elif recipe_type == CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP:
group_size = kwargs.pop("group_size", 32)
return self._build_torchao_quantized_recipe(
recipe_type,
weight_dtype=torch.int8,
is_per_channel=False,
group_size=group_size,
**kwargs,
)
elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
bits = kwargs.pop("bits", 3)
block_size = kwargs.pop("block_size", [-1, 16])
return self._build_codebook_quantized_recipe(
recipe_type, bits=bits, block_size=block_size, **kwargs
)

- return self._build_recipe(recipe_type, precision, **kwargs)
+ return None
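
Note (editorial, illustrative only): a minimal usage sketch of the dispatcher above, assuming CoreMLRecipeProvider and CoreMLRecipeType are both importable from executorch.backends.apple.coreml.recipes and that the provider takes no constructor arguments; the kwarg values are examples, not defaults from this PR.

import coremltools as ct

from executorch.backends.apple.coreml.recipes import (  # assumed import path
    CoreMLRecipeProvider,
    CoreMLRecipeType,
)

provider = CoreMLRecipeProvider()

# Plain FP16 lowering recipe, no quantization.
fp16_recipe = provider.create_recipe(CoreMLRecipeType.FP16)

# 4-bit weight-only quantization with per-group granularity; group_size is
# optional and defaults to 32 in this PR.
int4_recipe = provider.create_recipe(
    CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
    group_size=64,
    compute_unit=ct.ComputeUnit.CPU_AND_NE,
)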

def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
- if not kwargs:
- return
- expected_keys = {"minimum_deployment_target", "compute_unit"}
+ """Validate kwargs for each recipe type"""
+ expected_keys = self._get_expected_keys(recipe_type)

unexpected = set(kwargs.keys()) - expected_keys
if unexpected:
raise ValueError(
f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. "
f"Unexpected parameters: {list(unexpected)}"
f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}"
)

self._validate_base_parameters(kwargs)
self._validate_group_size_parameter(recipe_type, kwargs)
self._validate_codebook_parameters(recipe_type, kwargs)

def _get_expected_keys(self, recipe_type: RecipeType) -> set:
"""Get expected parameter keys for a recipe type"""
common_keys = {"minimum_deployment_target", "compute_unit"}

if recipe_type in [
CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
]:
return common_keys | {"group_size", "filter_fn"}
elif recipe_type in [
CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_CHANNEL,
CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL,
]:
return common_keys | {"filter_fn"}
elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
return common_keys | {"bits", "block_size", "filter_fn"}
else:
return common_keys

def _validate_base_parameters(self, kwargs: Any) -> None:
"""Validate minimum_deployment_target and compute_unit parameters"""
if "minimum_deployment_target" in kwargs:
minimum_deployment_target = kwargs["minimum_deployment_target"]
if not isinstance(minimum_deployment_target, ct.target):
raise ValueError(
f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}"
)

if "compute_unit" in kwargs:
compute_unit = kwargs["compute_unit"]
if not isinstance(compute_unit, ct.ComputeUnit):
raise ValueError(
f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}"
)

- def _build_recipe(
+ def _validate_group_size_parameter(
self, recipe_type: RecipeType, kwargs: Any
) -> None:
"""Validate group_size parameter for applicable recipe types"""
if (
recipe_type
in [
CoreMLRecipeType.INT4_WEIGHT_ONLY_PER_GROUP,
CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_GROUP,
]
and "group_size" in kwargs
):
group_size = kwargs["group_size"]
if not isinstance(group_size, int):
raise ValueError(
f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
)
if group_size <= 0:
raise ValueError(
f"Parameter 'group_size' must be positive, got: {group_size}"
)

def _validate_codebook_parameters(
self, recipe_type: RecipeType, kwargs: Any
) -> None:
"""Validate bits and block_size parameters for codebook recipe type"""
if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY:
return

if "bits" in kwargs:
bits = kwargs["bits"]
if not isinstance(bits, int):
raise ValueError(
f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}"
)
if not (1 <= bits <= 8):
raise ValueError(
f"Parameter 'bits' must be between 1 and 8, got: {bits}"
)

if "block_size" in kwargs:
block_size = kwargs["block_size"]
if not isinstance(block_size, list):
raise ValueError(
f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}"
)

def _validate_and_set_deployment_target(
self, kwargs: Any, min_target: ct.target, quantization_type: str
) -> None:
"""Validate or set minimum deployment target for quantization recipes"""
minimum_deployment_target = kwargs.get("minimum_deployment_target", None)
if minimum_deployment_target and minimum_deployment_target < min_target:
raise ValueError(
f"minimum_deployment_target must be {str(min_target)} or higher for {quantization_type} quantization"
)
elif minimum_deployment_target is None:
# Default to the minimum target for this quantization type; an explicit
# higher target supplied by the caller is left untouched
kwargs["minimum_deployment_target"] = min_target
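
Note (editorial, illustrative only): a quick sketch of how this guard behaves for the torchao-backed recipes, which use an iOS18 floor (see the comment referencing GH issue #13122 below); imports and provider construction as in the earlier sketch.

import coremltools as ct

provider = CoreMLRecipeProvider()
try:
    # iOS16 is below the iOS18 floor enforced for torchao-based recipes.
    provider.create_recipe(
        CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL,
        minimum_deployment_target=ct.target.iOS16,
    )
except ValueError as err:
    print(err)  # minimum_deployment_target must be ... or higher for torchao quantization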

def _build_fp_recipe(
self,
recipe_type: RecipeType,
precision: ct.precision,
**kwargs: Any,
) -> ExportRecipe:
"""Build FP32/FP16 recipe"""
lowering_recipe = self._get_coreml_lowering_recipe(
compute_precision=precision,
**kwargs,
)

return ExportRecipe(
name=recipe_type.value,
quantization_recipe=None, # TODO - add quantization recipe
lowering_recipe=lowering_recipe,
)

def _build_pt2e_quantized_recipe(
self,
recipe_type: RecipeType,
activation_dtype: torch.dtype,
**kwargs: Any,
) -> ExportRecipe:
"""Build PT2E-based quantization recipe"""
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer

self._validate_and_set_deployment_target(kwargs, ct.target.iOS17, "pt2e")

# Validate activation_dtype
assert activation_dtype in [
torch.quint8,
torch.float32,
], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}"

# Create quantization config
config = ct.optimize.torch.quantization.LinearQuantizerConfig(
global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
quantization_scheme="symmetric",
activation_dtype=activation_dtype,
weight_dtype=torch.qint8,
weight_per_channel=True,
)
)

quantizer = CoreMLQuantizer(config)
quantization_recipe = QuantizationRecipe(quantizers=[quantizer])

lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)

return ExportRecipe(
name=recipe_type.value,
quantization_recipe=quantization_recipe,
lowering_recipe=lowering_recipe,
)
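
Note (editorial, illustrative only): requesting the PT2E recipes with their defaults; per the code above, the static variant quantizes activations to quint8 and weights to qint8 (symmetric, per-channel), and the deployment target defaults to iOS17 when not supplied. Imports as in the earlier sketch.

provider = CoreMLRecipeProvider()

# Static int8: quint8 activations + qint8 weights.
static_recipe = provider.create_recipe(CoreMLRecipeType.PT2E_INT8_STATIC)

# Weight-only int8: activations stay float32, weights qint8.
weight_only_recipe = provider.create_recipe(CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY)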

def _build_torchao_quantized_recipe(
self,
recipe_type: RecipeType,
weight_dtype: torch.dtype,
is_per_channel: bool,
group_size: int = 32,
**kwargs: Any,
) -> ExportRecipe:
"""Build TorchAO-based quantization recipe"""
if is_per_channel:
weight_granularity = PerAxis(axis=0)
else:
weight_granularity = PerGroup(group_size=group_size)

# Use user-provided filter_fn if provided
filter_fn = kwargs.get("filter_fn", None)
config = AOQuantizationConfig(
ao_base_config=IntxWeightOnlyConfig(
weight_dtype=weight_dtype,
granularity=weight_granularity,
),
filter_fn=filter_fn,
)

quantization_recipe = QuantizationRecipe(
quantizers=None,
ao_quantization_configs=[config],
)

# override minimum_deployment_target to ios18 for torchao (GH issue #13122)
self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao")
lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)

return ExportRecipe(
name=recipe_type.value,
quantization_recipe=quantization_recipe,
lowering_recipe=lowering_recipe,
)
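
Note (editorial, illustrative only): the optional filter_fn kwarg can narrow which modules get quantized; the (module, fqn) callable signature assumed here mirrors the codebook default in the next method. Imports as in the earlier sketch.

import torch

# Quantize only nn.Linear weights, leaving everything else untouched.
recipe = CoreMLRecipeProvider().create_recipe(
    CoreMLRecipeType.INT8_WEIGHT_ONLY_PER_CHANNEL,
    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear),
)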

def _build_codebook_quantized_recipe(
self,
recipe_type: RecipeType,
bits: int,
block_size: list,
**kwargs: Any,
) -> ExportRecipe:
"""Build codebook/palettization quantization recipe"""
from torchao.prototype.quantization.codebook_coreml import (
CodebookWeightOnlyConfig,
)

self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "codebook")

# Get the appropriate dtype (torch.uint1 through torch.uint8)
dtype = getattr(torch, f"uint{bits}")

# Use user-provided filter_fn or default to Linear/Embedding layers
filter_fn = kwargs.get(
"filter_fn",
lambda m, fqn: (
isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear)
),
)

config = AOQuantizationConfig(
ao_base_config=CodebookWeightOnlyConfig(
dtype=dtype,
block_size=block_size,
),
filter_fn=filter_fn,
)

quantization_recipe = QuantizationRecipe(
quantizers=None,
ao_quantization_configs=[config],
)

lowering_recipe = self._get_coreml_lowering_recipe(**kwargs)

return ExportRecipe(
name=recipe_type.value,
quantization_recipe=quantization_recipe,
lowering_recipe=lowering_recipe,
)
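
Note (editorial, illustrative only): a sketch of the codebook recipe kwargs; bits must be an int in [1, 8] and block_size must be a list, per the validation earlier in this file. The values below are arbitrary examples, not recommendations.

# 4-bit palettization with an example block_size.
codebook_recipe = CoreMLRecipeProvider().create_recipe(
    CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY,
    bits=4,
    block_size=[-1, 32],
)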

def _get_coreml_lowering_recipe(
self,
- compute_precision: ct.precision,
+ compute_precision: ct.precision = ct.precision.FLOAT16,
**kwargs: Any,
) -> LoweringRecipe:
"""Get CoreML lowering recipe with optional precision"""
compile_specs = CoreMLBackend.generate_compile_specs(
compute_precision=compute_precision,
- **kwargs,
+ compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL),
+ minimum_deployment_target=kwargs.get("minimum_deployment_target", None),
)

minimum_deployment_target = kwargs.get("minimum_deployment_target", None)