Merged
3 changes: 3 additions & 0 deletions backends/xnnpack/TARGETS
@@ -36,7 +36,10 @@ runtime.python_library(
],
deps = [
":xnnpack_preprocess",
"//executorch/export:lib",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//executorch/backends/xnnpack/recipes:xnnpack_recipe_provider",
"//executorch/backends/xnnpack/recipes:xnnpack_recipe_types",
],
)
9 changes: 8 additions & 1 deletion backends/xnnpack/__init__.py
@@ -4,11 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.export import recipe_registry

# Exposed Partitioners in XNNPACK Package
from .partition.xnnpack_partitioner import (
XnnpackDynamicallyQuantizedPartitioner,
XnnpackPartitioner,
)
from .recipes.xnnpack_recipe_provider import XNNPACKRecipeProvider
from .recipes.xnnpack_recipe_types import XNNPackRecipeType

# Auto-register XNNPACK recipe provider
recipe_registry.register_backend_recipe_provider(XNNPACKRecipeProvider())

# Exposed Configs in XNNPACK Package
from .utils.configs import (
@@ -23,11 +30,11 @@
# XNNPACK Backend
from .xnnpack_preprocess import XnnpackBackend


__all__ = [
"XnnpackDynamicallyQuantizedPartitioner",
"XnnpackPartitioner",
"XnnpackBackend",
"XNNPackRecipeType",
"capture_graph_for_xnnpack",
"get_xnnpack_capture_config",
"get_xnnpack_edge_compile_config",
35 changes: 35 additions & 0 deletions backends/xnnpack/recipes/TARGETS
@@ -0,0 +1,35 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "xnnpack_recipe_provider",
srcs = [
"xnnpack_recipe_provider.py",
],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//caffe2:torch",
"//executorch/export:lib",
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
":xnnpack_recipe_types",
],
)

runtime.python_library(
name = "xnnpack_recipe_types",
srcs = [
"xnnpack_recipe_types.py",
],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//executorch/export:lib",
],
)
184 changes: 184 additions & 0 deletions backends/xnnpack/recipes/xnnpack_recipe_provider.py
@@ -0,0 +1,184 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Optional, Sequence

import torch

from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)

from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType
from executorch.backends.xnnpack.utils.configs import (
get_xnnpack_edge_compile_config,
get_xnnpack_executorch_backend_config,
)
from executorch.export import (
BackendRecipeProvider,
ExportRecipe,
QuantizationRecipe,
RecipeType,
)
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig


class XNNPACKRecipeProvider(BackendRecipeProvider):
@property
def backend_name(self) -> str:
return "xnnpack"
Contributor: we might already have this somewhere. Just want to make sure we are consistent.

def get_supported_recipes(self) -> Sequence[RecipeType]:
return list(XNNPackRecipeType)

def create_recipe(
self, recipe_type: RecipeType, **kwargs: Any
Contributor: Do we have to use kwargs? It makes it hard to read and easier to slip on typing checks or other validations. (See the typed-kwargs sketch after this file's diff.)
) -> Optional[ExportRecipe]:
"""Create XNNPACK recipe"""

if recipe_type not in self.get_supported_recipes():
return None

# Validate kwargs
self._validate_recipe_kwargs(recipe_type, **kwargs)

if recipe_type == XNNPackRecipeType.FP32:
return self._build_fp32_recipe(recipe_type)

elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL:
Contributor: nit: using a match stmt might be cleaner? (See the sketch after this method.)
return self._build_quantized_recipe(
recipe_type, is_per_channel=True, is_dynamic=True
)

elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_TENSOR:
return self._build_quantized_recipe(
recipe_type, is_per_channel=False, is_dynamic=True
)

elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_CHANNEL:
return self._build_quantized_recipe(
recipe_type, is_per_channel=True, is_dynamic=False
)

elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_TENSOR:
return self._build_quantized_recipe(
recipe_type, is_per_channel=False, is_dynamic=False
)

elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL:
return self._build_int8da_intx_weight_recipe(
recipe_type=recipe_type,
is_per_channel=True,
weight_dtype=torch.int4,
)

elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
group_size = kwargs.get("group_size", 32)
return self._build_int8da_intx_weight_recipe(
recipe_type=recipe_type,
is_per_channel=False,
weight_dtype=torch.int4,
group_size=group_size,
)
return None
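
A minimal sketch of the match-statement dispatch the reviewer suggests (requires Python 3.10+ value patterns); it mirrors the if/elif chain above and uses only names already defined in this file:

    match recipe_type:
        case XNNPackRecipeType.FP32:
            return self._build_fp32_recipe(recipe_type)
        case XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL:
            return self._build_quantized_recipe(
                recipe_type, is_per_channel=True, is_dynamic=True
            )
        case XNNPackRecipeType.INT8_DYNAMIC_PER_TENSOR:
            return self._build_quantized_recipe(
                recipe_type, is_per_channel=False, is_dynamic=True
            )
        case XNNPackRecipeType.INT8_STATIC_PER_CHANNEL:
            return self._build_quantized_recipe(
                recipe_type, is_per_channel=True, is_dynamic=False
            )
        case XNNPackRecipeType.INT8_STATIC_PER_TENSOR:
            return self._build_quantized_recipe(
                recipe_type, is_per_channel=False, is_dynamic=False
            )
        case XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL:
            return self._build_int8da_intx_weight_recipe(
                recipe_type=recipe_type,
                is_per_channel=True,
                weight_dtype=torch.int4,
            )
        case XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
            return self._build_int8da_intx_weight_recipe(
                recipe_type=recipe_type,
                is_per_channel=False,
                weight_dtype=torch.int4,
                group_size=kwargs.get("group_size", 32),
            )
        case _:
            return None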

def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe:
return ExportRecipe(
name=recipe_type.value,
edge_compile_config=get_xnnpack_edge_compile_config(),
executorch_backend_config=get_xnnpack_executorch_backend_config(),
partitioners=[XnnpackPartitioner()],
)

def _build_quantized_recipe(
self,
recipe_type: RecipeType,
is_per_channel: bool = True,
is_dynamic: bool = True,
is_qat: bool = False,
) -> ExportRecipe:
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=is_per_channel, is_dynamic=is_dynamic, is_qat=is_qat
)
quantizer.set_global(operator_config)

quant_recipe = QuantizationRecipe(quantizers=[quantizer])

precision_type = (
ConfigPrecisionType.DYNAMIC_QUANT
if is_dynamic
else ConfigPrecisionType.STATIC_QUANT
)

return ExportRecipe(
name=recipe_type.value,
quantization_recipe=quant_recipe,
edge_compile_config=get_xnnpack_edge_compile_config(),
executorch_backend_config=get_xnnpack_executorch_backend_config(),
partitioners=[XnnpackPartitioner(config_precision=precision_type)],
)

def _build_int8da_intx_weight_recipe(
self,
recipe_type: RecipeType,
is_per_channel: bool = True,
weight_dtype: torch.dtype = torch.int4,
group_size: int = 32,
) -> ExportRecipe:
if is_per_channel:
weight_granularity = PerAxis(axis=0)
else:
weight_granularity = PerGroup(group_size=group_size)

config = Int8DynamicActivationIntxWeightConfig(
weight_dtype=weight_dtype,
weight_granularity=weight_granularity,
)

quant_recipe = QuantizationRecipe(
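# quantizers=None: quantization here presumably goes through the torchao
# ao_base_config path below rather than a PT2E quantizer.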
quantizers=None,
Contributor: why is this none?
ao_base_config=[config],
)

return ExportRecipe(
name=recipe_type.value,
quantization_recipe=quant_recipe,
edge_compile_config=get_xnnpack_edge_compile_config(),
executorch_backend_config=get_xnnpack_executorch_backend_config(),
partitioners=[XnnpackPartitioner()],
)

def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
if recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
expected_keys = {"group_size"}
unexpected = set(kwargs.keys()) - expected_keys
if unexpected:
raise ValueError(
f"Recipe '{recipe_type.value}' only accepts 'group_size' parameter. "
f"Unexpected parameters: {list(unexpected)}"
)
if "group_size" in kwargs:
group_size = kwargs["group_size"]
if not isinstance(group_size, int):
raise ValueError(
f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
)
elif kwargs:
# All other recipes don't expect any kwargs
unexpected = list(kwargs.keys())
raise ValueError(
f"Recipe '{recipe_type.value}' does not accept any parameters. "
f"Unexpected parameters: {unexpected}"
)
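
On the kwargs readability comment above: a minimal sketch of one hypothetical alternative (not part of this PR) that keeps **kwargs at the BackendRecipeProvider boundary but normalizes it into a single validated value before dispatch; _parse_group_size is an illustrative name only:

    from typing import Any, Optional

    from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import (
        XNNPackRecipeType,
    )
    from executorch.export import RecipeType


    def _parse_group_size(recipe_type: RecipeType, **kwargs: Any) -> Optional[int]:
        """Hypothetical helper: return a validated group_size, or None when
        the recipe takes no parameters. Mirrors _validate_recipe_kwargs."""
        if recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
            group_size = kwargs.pop("group_size", 32)
            if not isinstance(group_size, int):
                raise ValueError(
                    f"'group_size' must be an int, got {type(group_size).__name__}"
                )
        else:
            group_size = None
        if kwargs:
            raise ValueError(
                f"Recipe '{recipe_type.value}' got unexpected parameters: {list(kwargs)}"
            )
        return group_size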
33 changes: 33 additions & 0 deletions backends/xnnpack/recipes/xnnpack_recipe_types.py
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from executorch.export import RecipeType


class XNNPackRecipeType(RecipeType):
"""XNNPACK-specific recipe types"""

FP32 = "fp32"
# INT8 Dynamic Quantization
INT8_DYNAMIC_PER_CHANNEL = "int8_dynamic_per_channel"
INT8_DYNAMIC_PER_TENSOR = "int8_dynamic_per_tensor"
# INT8 Dynamic Activations INT4 Weight Quantization, Axis = 0
INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8da_int4w_per_channel"
# INT8 Dynamic Activations INT4 Weight Quantization, default group_size = 32
# can be overridden by the group_size kwarg
INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8da_int4w_per_tensor"
# INT8 Static Activations INT4 Weight Quantization
INT8_STATIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8a_int4w_per_channel"
INT8_STATIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8a_int4w_per_tensor"
# INT8 Static Quantization, needs calibration dataset
INT8_STATIC_PER_CHANNEL = "int8_static_per_channel"
INT8_STATIC_PER_TENSOR = "int8_static_per_tensor"

@classmethod
def get_backend_name(cls) -> str:
return "xnnpack"
15 changes: 15 additions & 0 deletions backends/xnnpack/test/TARGETS
@@ -94,3 +94,18 @@ runtime.python_test(
"libtorch",
],
)

runtime.python_test(
name = "test_xnnpack_recipes",
srcs = glob([
"recipes/*.py",
]),
deps = [
"//executorch/backends/xnnpack:xnnpack_delegate",
"//executorch/export:lib",
"//pytorch/vision:torchvision", # @manual
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/examples/models:models", # @manual
"//executorch/examples/xnnpack:models", # @manual
],
)