Skip to content

Commit a711af7

Browse files
Add export recipes for xnnpack (#12069) (#12070)
Summary: Enables basic export recipes for the XNNPACK backend, as described in #12069. Differential Revision: D77414795
1 parent 45b8f35 commit a711af7

File tree

11 files changed

+552
-8
lines changed

11 files changed

+552
-8
lines changed

backends/xnnpack/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ runtime.python_library(
3636
],
3737
deps = [
3838
":xnnpack_preprocess",
39+
"//executorch/export:lib",
3940
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
4041
"//executorch/backends/xnnpack/utils:xnnpack_utils",
42+
"//executorch/backends/xnnpack/recipes:xnnpack_recipe_provider",
43+
"//executorch/backends/xnnpack/recipes:xnnpack_recipe_types",
4144
],
4245
)

backends/xnnpack/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from executorch.export import recipe_registry
8+
79
# Exposed Partitioners in XNNPACK Package
810
from .partition.xnnpack_partitioner import (
911
XnnpackDynamicallyQuantizedPartitioner,
1012
XnnpackPartitioner,
1113
)
14+
from .recipes.xnnpack_recipe_provider import XNNPACKRecipeProvider
15+
from .recipes.xnnpack_recipe_types import XNNPackRecipeType
16+
17+
# Auto-register XNNPACK recipe provider
18+
recipe_registry.register_backend_recipe_provider(XNNPACKRecipeProvider())
1219

1320
# Exposed Configs in XNNPACK Package
1421
from .utils.configs import (
@@ -23,11 +30,11 @@
2330
# XNNPACK Backend
2431
from .xnnpack_preprocess import XnnpackBackend
2532

26-
2733
__all__ = [
2834
"XnnpackDynamicallyQuantizedPartitioner",
2935
"XnnpackPartitioner",
3036
"XnnpackBackend",
37+
"XNNPackRecipeType",
3138
"capture_graph_for_xnnpack",
3239
"get_xnnpack_capture_config",
3340
"get_xnnpack_edge_compile_config",

backends/xnnpack/recipes/TARGETS

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

# Recipe provider: builds ExportRecipes for the XNNPACK backend.
runtime.python_library(
    name = "xnnpack_recipe_provider",
    srcs = [
        "xnnpack_recipe_provider.py",
    ],
    visibility = [
        "//executorch/...",
        "@EXECUTORCH_CLIENTS",
    ],
    deps = [
        ":xnnpack_recipe_types",
        "//caffe2:torch",
        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
        "//executorch/export:lib",
    ],
)

# Recipe type enum, split out so it can be depended on without the provider.
runtime.python_library(
    name = "xnnpack_recipe_types",
    srcs = [
        "xnnpack_recipe_types.py",
    ],
    visibility = [
        "//executorch/...",
        "@EXECUTORCH_CLIENTS",
    ],
    deps = [
        "//executorch/export:lib",
    ],
)
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Any, Optional, Sequence
10+
11+
import torch
12+
13+
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
14+
ConfigPrecisionType,
15+
)
16+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
17+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
18+
get_symmetric_quantization_config,
19+
XNNPACKQuantizer,
20+
)
21+
22+
from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType
23+
from executorch.backends.xnnpack.utils.configs import (
24+
get_xnnpack_edge_compile_config,
25+
get_xnnpack_executorch_backend_config,
26+
)
27+
from executorch.export import (
28+
BackendRecipeProvider,
29+
ExportRecipe,
30+
QuantizationRecipe,
31+
RecipeType,
32+
)
33+
from torchao.quantization.granularity import PerAxis, PerGroup
34+
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
35+
36+
37+
class XNNPACKRecipeProvider(BackendRecipeProvider):
    """Provides XNNPACK export recipes to the executorch recipe registry.

    Each recipe bundles an optional quantization setup, the XNNPACK
    edge-compile and executorch backend configs, and an ``XnnpackPartitioner``
    into an ``ExportRecipe``.
    """

    @property
    def backend_name(self) -> str:
        return "xnnpack"

    def get_supported_recipes(self) -> Sequence[RecipeType]:
        # NOTE(review): this advertises every XNNPackRecipeType member, but
        # create_recipe() has no branch for INT8_STATIC_ACT_INT4_WEIGHT_PER_CHANNEL
        # or INT8_STATIC_ACT_INT4_WEIGHT_PER_TENSOR, so those silently yield
        # None — confirm whether they should be excluded here or implemented.
        return list(XNNPackRecipeType)

    def create_recipe(
        self, recipe_type: RecipeType, **kwargs: Any
    ) -> Optional[ExportRecipe]:
        """Create an XNNPACK recipe for ``recipe_type``.

        Args:
            recipe_type: The recipe to build; expected to be an
                ``XNNPackRecipeType`` member.
            **kwargs: Recipe-specific options. Only
                ``INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR`` accepts a
                ``group_size`` (int, default 32); all other recipes accept
                no kwargs.

        Returns:
            The configured ``ExportRecipe``, or ``None`` when ``recipe_type``
            is not handled by this provider.

        Raises:
            ValueError: If unexpected or invalid kwargs are passed for the
                selected recipe type.
        """

        if recipe_type not in self.get_supported_recipes():
            return None

        # Fail fast on kwargs the selected recipe does not understand.
        self._validate_recipe_kwargs(recipe_type, **kwargs)

        if recipe_type == XNNPackRecipeType.FP32:
            return self._build_fp32_recipe(recipe_type)

        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL:
            return self._build_quantized_recipe(
                recipe_type, is_per_channel=True, is_dynamic=True
            )

        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_TENSOR:
            return self._build_quantized_recipe(
                recipe_type, is_per_channel=False, is_dynamic=True
            )

        elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_CHANNEL:
            return self._build_quantized_recipe(
                recipe_type, is_per_channel=True, is_dynamic=False
            )

        elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_TENSOR:
            return self._build_quantized_recipe(
                recipe_type, is_per_channel=False, is_dynamic=False
            )

        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL:
            return self._build_int8da_intx_weight_recipe(
                recipe_type=recipe_type,
                is_per_channel=True,
                weight_dtype=torch.int4,
            )

        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
            group_size = kwargs.get("group_size", 32)
            return self._build_int8da_intx_weight_recipe(
                recipe_type=recipe_type,
                is_per_channel=False,
                weight_dtype=torch.int4,
                group_size=group_size,
            )
        # Reached only for XNNPackRecipeType members with no branch above
        # (currently the INT8_STATIC_ACT_INT4_WEIGHT_* types).
        return None

    def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe:
        """Build the unquantized (FP32) recipe."""
        return ExportRecipe(
            name=recipe_type.value,
            edge_compile_config=get_xnnpack_edge_compile_config(),
            executorch_backend_config=get_xnnpack_executorch_backend_config(),
            partitioners=[XnnpackPartitioner()],
        )

    def _build_quantized_recipe(
        self,
        recipe_type: RecipeType,
        is_per_channel: bool = True,
        is_dynamic: bool = True,
        is_qat: bool = False,
    ) -> ExportRecipe:
        """Build an INT8 recipe using the XNNPACKQuantizer.

        Args:
            recipe_type: Recipe whose value names the resulting ExportRecipe.
            is_per_channel: Per-channel (True) vs per-tensor quantization.
            is_dynamic: Dynamic (True) vs static activation quantization.
            is_qat: Whether to configure for quantization-aware training.
        """
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(
            is_per_channel=is_per_channel, is_dynamic=is_dynamic, is_qat=is_qat
        )
        quantizer.set_global(operator_config)

        quant_recipe = QuantizationRecipe(quantizers=[quantizer])

        # The partitioner must match the quantization flavor so that the
        # delegated subgraphs line up with what the quantizer produced.
        precision_type = (
            ConfigPrecisionType.DYNAMIC_QUANT
            if is_dynamic
            else ConfigPrecisionType.STATIC_QUANT
        )

        return ExportRecipe(
            name=recipe_type.value,
            quantization_recipe=quant_recipe,
            edge_compile_config=get_xnnpack_edge_compile_config(),
            executorch_backend_config=get_xnnpack_executorch_backend_config(),
            partitioners=[XnnpackPartitioner(config_precision=precision_type)],
        )

    def _build_int8da_intx_weight_recipe(
        self,
        recipe_type: RecipeType,
        is_per_channel: bool = True,
        weight_dtype: torch.dtype = torch.int4,
        group_size: int = 32,
    ) -> ExportRecipe:
        """Build an INT8-dynamic-activation / intx-weight recipe via torchao.

        Args:
            recipe_type: Recipe whose value names the resulting ExportRecipe.
            is_per_channel: Per-channel (axis 0) weight granularity when True;
                otherwise per-group with ``group_size``.
            weight_dtype: torchao weight dtype (default ``torch.int4``).
            group_size: Group size for per-group granularity; ignored when
                ``is_per_channel`` is True.
        """
        if is_per_channel:
            weight_granularity = PerAxis(axis=0)
        else:
            weight_granularity = PerGroup(group_size=group_size)

        config = Int8DynamicActivationIntxWeightConfig(
            weight_dtype=weight_dtype,
            weight_granularity=weight_granularity,
        )

        # torchao source-transform path: no PT2E quantizers, only an AO config.
        quant_recipe = QuantizationRecipe(
            quantizers=None,
            ao_base_config=[config],
        )

        return ExportRecipe(
            name=recipe_type.value,
            quantization_recipe=quant_recipe,
            edge_compile_config=get_xnnpack_edge_compile_config(),
            executorch_backend_config=get_xnnpack_executorch_backend_config(),
            partitioners=[XnnpackPartitioner()],
        )

    def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
        """Reject kwargs the selected recipe type does not accept.

        Raises:
            ValueError: On unexpected parameter names, or when ``group_size``
                is not an ``int``.
        """
        if recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
            expected_keys = {"group_size"}
            unexpected = set(kwargs.keys()) - expected_keys
            if unexpected:
                raise ValueError(
                    f"Recipe '{recipe_type.value}' only accepts 'group_size' parameter. "
                    f"Unexpected parameters: {list(unexpected)}"
                )
            if "group_size" in kwargs:
                group_size = kwargs["group_size"]
                # Exclude bool explicitly: isinstance(True, int) is True in
                # Python, but True/False are not valid group sizes.
                if not isinstance(group_size, int) or isinstance(group_size, bool):
                    raise ValueError(
                        f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}"
                    )
        elif kwargs:
            # All other recipes don't expect any kwargs
            unexpected = list(kwargs.keys())
            raise ValueError(
                f"Recipe '{recipe_type.value}' does not accept any parameters. "
                f"Unexpected parameters: {unexpected}"
            )
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from executorch.export import RecipeType
10+
11+
12+
class XNNPackRecipeType(RecipeType):
    """XNNPACK-specific recipe types.

    Enum values are the string names used to look up recipes in the
    executorch recipe registry.
    """

    # Plain floating-point export, no quantization.
    FP32 = "fp32"
    # INT8 Dynamic Quantization
    INT8_DYNAMIC_PER_CHANNEL = "int8_dynamic_per_channel"
    INT8_DYNAMIC_PER_TENSOR = "int8_dynamic_per_tensor"
    # INT8 Dynamic Activations INT4 Weight Quantization, Axis = 0
    INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8da_int4w_per_channel"
    # INT8 Dynamic Activations INT4 Weight Quantization, default group_size = 32
    # can be overridden by group_size kwarg
    INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8da_int4w_per_tensor"
    # INT8 Static Activations INT4 Weight Quantization
    INT8_STATIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8a_int4w_per_channel"
    # Fixed typo: value was "int8a_int44w_per_tensor"; now consistent with the
    # per-channel sibling and the dynamic-activation naming scheme.
    INT8_STATIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8a_int4w_per_tensor"
    # INT8 Static Quantization, needs calibration dataset
    INT8_STATIC_PER_CHANNEL = "int8_static_per_channel"
    INT8_STATIC_PER_TENSOR = "int8_static_per_tensor"

    @classmethod
    def get_backend_name(cls) -> str:
        # Matches XNNPACKRecipeProvider.backend_name for registry routing.
        return "xnnpack"

backends/xnnpack/test/TARGETS

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,18 @@ runtime.python_test(
9494
"libtorch",
9595
],
9696
)
97+
98+
# Unit tests for the XNNPACK export recipes.
runtime.python_test(
    name = "test_xnnpack_recipes",
    srcs = glob([
        "recipes/*.py",
    ]),
    deps = [
        "//executorch/backends/xnnpack:xnnpack_delegate",
        "//executorch/backends/xnnpack/test/tester:tester",
        "//executorch/examples/models:models",  # @manual
        "//executorch/examples/xnnpack:models",  # @manual
        "//executorch/export:lib",
        "//pytorch/vision:torchvision",  # @manual
    ],
)

0 commit comments

Comments
 (0)