Skip to content

Commit 3290db4

Browse files
Add export recipes for xnnpack (#12069) (#12070)
Summary: Pull Request resolved: #12070. Enables the basic export recipes for the XNNPACK backend described in #12069. Adds five recipes: (a) static per-channel quant, (b) static per-tensor quant, (c) dynamic per-channel quant, (d) fp32, and (e) 8a4w. Differential Revision: D77414795
1 parent 02454eb commit 3290db4

File tree

9 files changed

+334
-19
lines changed

9 files changed

+334
-19
lines changed

backends/xnnpack/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,6 @@ runtime.python_library(
3838
":xnnpack_preprocess",
3939
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
4040
"//executorch/backends/xnnpack/utils:xnnpack_utils",
41+
"//executorch/backends/xnnpack/recipes:xnnpack_recipes"
4142
],
4243
)

backends/xnnpack/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
XnnpackDynamicallyQuantizedPartitioner,
1010
XnnpackPartitioner,
1111
)
12+
from .recipes.recipes import get_xnnpack_recipe
1213

1314
# Exposed Configs in XNNPACK Package
1415
from .utils.configs import (
@@ -23,12 +24,12 @@
2324
# XNNPACK Backend
2425
from .xnnpack_preprocess import XnnpackBackend
2526

26-
2727
__all__ = [
2828
"XnnpackDynamicallyQuantizedPartitioner",
2929
"XnnpackPartitioner",
3030
"XnnpackBackend",
3131
"capture_graph_for_xnnpack",
32+
"get_xnnpack_recipe",
3233
"get_xnnpack_capture_config",
3334
"get_xnnpack_edge_compile_config",
3435
"get_xnnpack_executorch_backend_config",

backends/xnnpack/recipes/TARGETS

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "xnnpack_recipes",
7+
srcs = [
8+
"recipes.py",
9+
],
10+
visibility = [
11+
"//executorch/...",
12+
"@EXECUTORCH_CLIENTS",
13+
],
14+
deps = [
15+
"//caffe2:torch",
16+
"//executorch/exir:lib",
17+
"//executorch/export:recipe",
18+
"//executorch/backends/transforms:duplicate_dynamic_quant_chain",
19+
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
20+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
21+
],
22+
)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
from functools import partial
9+
from typing import Any, Callable
10+
11+
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
12+
ConfigPrecisionType,
13+
)
14+
15+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
16+
17+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
18+
get_symmetric_quantization_config,
19+
XNNPACKQuantizer,
20+
)
21+
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
22+
from executorch.export.recipe import ExportRecipe, QuantizationRecipe
23+
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight
24+
25+
26+
def get_fp32_recipe() -> ExportRecipe:
    """Build the unquantized XNNPACK export recipe.

    Returns:
        An ``ExportRecipe`` named ``"fp32"`` that carries no quantization
        recipe and a single default-constructed ``XnnpackPartitioner``.
    """
    fp32_partitioner = XnnpackPartitioner()
    return ExportRecipe(
        name="fp32",
        quantization_recipe=None,
        partitioners=[fp32_partitioner],
    )
32+
33+
34+
def get_quant_recipe(
    quant_recipe_name: str,
    is_per_channel: bool,
    is_dynamic: bool,
    is_qat: bool = False,
    **_kwargs: Any,
) -> ExportRecipe:
    """Build an XNNPACK export recipe backed by ``XNNPACKQuantizer``.

    Args:
        quant_recipe_name: Name given to the returned ``ExportRecipe``.
        is_per_channel: Forwarded to ``get_symmetric_quantization_config``.
        is_dynamic: Forwarded to ``get_symmetric_quantization_config``; also
            selects the partitioner precision (``DYNAMIC_QUANT`` when true,
            ``STATIC_QUANT`` otherwise).
        is_qat: Forwarded to ``get_symmetric_quantization_config``.
        **_kwargs: Accepted and ignored.

    Returns:
        An ``ExportRecipe`` combining the quantization recipe, one
        ``XnnpackPartitioner`` and the XNNPACK edge compile config.
    """
    # Quantizer with one symmetric config applied globally.
    xnnpack_quantizer = XNNPACKQuantizer()
    symmetric_config = get_symmetric_quantization_config(
        is_per_channel=is_per_channel,
        is_dynamic=is_dynamic,
        is_qat=is_qat,
    )
    xnnpack_quantizer.set_global(symmetric_config)

    quantization = QuantizationRecipe(quantizers=[xnnpack_quantizer])

    # The partitioner precision has to agree with the quantization mode.
    if is_dynamic:
        precision = ConfigPrecisionType.DYNAMIC_QUANT
    else:
        precision = ConfigPrecisionType.STATIC_QUANT

    return ExportRecipe(
        name=quant_recipe_name,
        quantization_recipe=quantization,
        partitioners=[XnnpackPartitioner(config_precision=precision)],
        edge_compile_config=get_xnnpack_edge_compile_config(),
    )
56+
57+
58+
def get_8a4w_config(group_size: int = 32) -> ExportRecipe:
    """Build the ``"8a4w_quant"`` XNNPACK export recipe.

    Args:
        group_size: Forwarded as ``group_size`` to torchao's
            ``int8_dynamic_activation_int4_weight`` config.

    Returns:
        An ``ExportRecipe`` whose quantization recipe has no quantizers and a
        single torchao base config, paired with a default-constructed
        ``XnnpackPartitioner``.
    """
    ao_config = int8_dynamic_activation_int4_weight(group_size=group_size)
    quantization = QuantizationRecipe(
        quantizers=None,
        ao_base_config=[ao_config],
    )

    return ExportRecipe(
        name="8a4w_quant",
        quantization_recipe=quantization,
        partitioners=[XnnpackPartitioner()],
    )
73+
74+
75+
# Maps the public recipe names accepted by `get_xnnpack_recipe` to factories.
RECIPE_MAP: dict[str, Callable[..., ExportRecipe]] = {
    "FP32_RECIPE": get_fp32_recipe,
    # Generic entry: the caller passes every `get_quant_recipe` argument.
    "QUANT_RECIPE": get_quant_recipe,
    # Pre-bound variants of `get_quant_recipe`.
    "DYNAMIC_PER_CHANNEL_QUANT_RECIPE": partial(
        get_quant_recipe,
        "dynamic_per_channel_quant",
        is_per_channel=True,
        is_dynamic=True,
    ),
    "STATIC_PER_CHANNEL_QUANT_RECIPE": partial(
        get_quant_recipe,
        "static_per_channel_quant",
        is_per_channel=True,
        is_dynamic=False,
    ),
    "STATIC_PER_TENSOR_QUANT_RECIPE": partial(
        get_quant_recipe,
        "static_per_tensor_quant",
        is_per_channel=False,
        is_dynamic=False,
    ),
    "8A4W_ACCELERATED_RECIPE": get_8a4w_config,
}
83+
84+
85+
def get_xnnpack_recipe(recipe_name: str, **kwargs: Any) -> ExportRecipe:
    """Look up an XNNPACK export recipe by name and build it.

    Args:
        recipe_name: A key of ``RECIPE_MAP``.
        **kwargs: Forwarded unchanged to the selected recipe factory.

    Returns:
        The ``ExportRecipe`` produced by the selected factory.

    Raises:
        AssertionError: If ``recipe_name`` is not a key of ``RECIPE_MAP``.
    """
    if recipe_name not in RECIPE_MAP:
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type and message are kept so that
        # existing callers see the same failure.
        raise AssertionError(f"Recipe {recipe_name} not found.")
    return RECIPE_MAP[recipe_name](**kwargs)

backends/xnnpack/test/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,16 @@ runtime.python_test(
9494
"libtorch",
9595
],
9696
)
97+
98+
runtime.python_test(
99+
name = "test_xnnpack_recipes",
100+
srcs = glob([
101+
"recipes/*.py",
102+
]),
103+
deps = [
104+
"//executorch/backends/xnnpack:xnnpack_delegate",
105+
"//executorch/export:lib",
106+
"//pytorch/vision:torchvision", # @manual,
107+
"//executorch/backends/xnnpack/test/tester:tester",
108+
],
109+
)
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import torch
12+
from executorch.backends.xnnpack import get_xnnpack_recipe
13+
from executorch.exir.schema import DelegateCall, Program
14+
from executorch.export import export
15+
from torch import nn
16+
from torch.testing._internal.common_quantization import TestHelperModules
17+
from torchvision import models
18+
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
19+
from executorch.backends.xnnpack.test.tester import Tester
20+
from torchvision.models.segmentation import deeplabv3, deeplabv3_resnet50 # @manual
21+
22+
23+
class TestXnnpackRecipes(unittest.TestCase):
    """Exports models with the XNNPACK recipes and compares against eager.

    Each test exports a model through ``export`` with a recipe obtained from
    ``get_xnnpack_recipe``, runs the ``forward`` method of the resulting
    session, and compares the first output with the eager model's output.
    """

    def setUp(self) -> None:
        torch._dynamo.reset()
        super().setUp()

    def check_fully_delegated(self, program: Program) -> None:
        """Assert the first chain of ``program`` is one ``DelegateCall``."""
        instructions = program.execution_plan[0].chains[0].instructions
        # NOTE(review): presumably narrows an Optional for the type checker.
        assert instructions is not None
        self.assertEqual(len(instructions), 1)
        self.assertIsInstance(instructions[0].instr_args, DelegateCall)

    def test_basic_recipe(self) -> None:
        m_eager = TestHelperModules.TwoLinearModule().eval()
        example_inputs = [(torch.randn(9, 8),)]
        session = export(
            model=m_eager,
            example_inputs=example_inputs,
            export_recipe=get_xnnpack_recipe("FP32_RECIPE"),
        )
        self.assertTrue(
            torch.allclose(
                session.run_method("forward", example_inputs[0])[0],
                m_eager(*example_inputs[0]),
                atol=1e-1,
            )
        )
        self.check_fully_delegated(session.get_executorch_program())

    def test_dynamic_quant_recipe(self) -> None:
        with torch.no_grad():
            m_eager = TestHelperModules.TwoLinearModule().eval()
            example_inputs = [(torch.randn(9, 8),)]
            session = export(
                model=m_eager,
                example_inputs=example_inputs,
                export_recipe=get_xnnpack_recipe(
                    "DYNAMIC_PER_CHANNEL_QUANT_RECIPE"
                ),
            )
            self.assertTrue(
                torch.allclose(
                    session.run_method("forward", example_inputs[0])[0],
                    m_eager(*example_inputs[0]),
                    atol=1e-1,
                )
            )
            self.check_fully_delegated(session.get_executorch_program())

    def test_static_quant_recipe(self) -> None:
        with torch.no_grad():
            m_eager = TestHelperModules.TwoLinearModule().eval()
            example_inputs = [(torch.randn(9, 8),)]
            session = export(
                model=m_eager,
                example_inputs=example_inputs,
                export_recipe=get_xnnpack_recipe(
                    "STATIC_PER_CHANNEL_QUANT_RECIPE"
                ),
            )
            self.assertTrue(
                torch.allclose(
                    session.run_method("forward", example_inputs[0])[0],
                    m_eager(*example_inputs[0]),
                    atol=1e-1,
                )
            )
            self.check_fully_delegated(session.get_executorch_program())

    def test_8a4w_recipe(self) -> None:
        class SimpleLinearModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # in_features matches the group_size=32 requested below.
                self.layer1 = nn.Linear(32, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.layer1(x)

        model = SimpleLinearModel()
        example_inputs = [(torch.randn(1, 32),)]
        session = export(
            model=model,
            example_inputs=example_inputs,
            export_recipe=get_xnnpack_recipe(
                "8A4W_ACCELERATED_RECIPE", group_size=32
            ),
        )
        self.assertTrue(
            torch.allclose(
                session.run_method("forward", example_inputs[0])[0],
                model(*example_inputs[0]),
                atol=1e-1,
            )
        )
        self.check_fully_delegated(session.get_executorch_program())

    def test_mv3_model(self) -> None:
        # Explicit weights enum instead of the deprecated `pretrained=True`;
        # IMAGENET1K_V1 is the checkpoint the legacy flag resolved to.
        mv3 = models.mobilenetv3.mobilenet_v3_small(
            weights=models.mobilenetv3.MobileNet_V3_Small_Weights.IMAGENET1K_V1
        )
        mv3 = mv3.eval()
        model_inputs = [(torch.randn(1, 3, 224, 224),)]
        self.assertTrue(hasattr(mv3, "forward"))
        dynamic_shapes = (
            {
                2: torch.export.Dim("height", min=224, max=455),
                3: torch.export.Dim("width", min=224, max=455),
            },
        )
        session = export(
            model=mv3,
            example_inputs=model_inputs,
            dynamic_shapes=dynamic_shapes,
            export_recipe=get_xnnpack_recipe(
                "STATIC_PER_CHANNEL_QUANT_RECIPE"
            ),
        )

        Tester._assert_outputs_equal(
            session.run_method("forward", model_inputs[0])[0],
            mv3(*model_inputs[0]),
            atol=1e-3,
        )

    def test_mv2_model_with_static_quant_recipe(self) -> None:
        # Pass an enum member, not the enum class: the class was only accepted
        # through torchvision's legacy-argument fallback, which picks V1.
        mv2 = models.mobilenetv2.mobilenet_v2(
            weights=MobileNet_V2_Weights.IMAGENET1K_V1
        )
        mv2 = mv2.eval()
        model_inputs = [(torch.randn(1, 3, 224, 224),)]
        self.assertTrue(hasattr(mv2, "forward"))
        dynamic_shapes = (
            {
                2: torch.export.Dim("height", min=224, max=455),
                3: torch.export.Dim("width", min=224, max=455),
            },
        )
        session = export(
            model=mv2,
            example_inputs=model_inputs,
            dynamic_shapes=dynamic_shapes,
            export_recipe=get_xnnpack_recipe(
                "STATIC_PER_CHANNEL_QUANT_RECIPE"
            ),
        )

        Tester._assert_outputs_equal(
            session.run_method("forward", model_inputs[0])[0],
            mv2(*model_inputs[0]),
            atol=1e-3,
        )

    def test_dl3_with_recipe(self) -> None:
        class DL3Wrapper(torch.nn.Module):
            """Wraps DeepLabV3 so ``forward`` returns only the "out" entry."""

            def __init__(self) -> None:
                super().__init__()
                self.m = deeplabv3_resnet50(
                    weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT
                )

            def forward(self, *args: torch.Tensor) -> torch.Tensor:
                return self.m(*args)["out"]

        dl3 = DL3Wrapper()
        dl3 = dl3.eval()
        model_inputs = [(torch.randn(1, 3, 224, 224),)]
        self.assertTrue(hasattr(dl3, "forward"))
        session = export(
            model=dl3,
            example_inputs=model_inputs,
            export_recipe=get_xnnpack_recipe(
                "STATIC_PER_CHANNEL_QUANT_RECIPE"
            ),
        )

        Tester._assert_outputs_equal(
            session.run_method("forward", model_inputs[0])[0],
            dl3(*model_inputs[0]),
            atol=1e-3,
        )
192+

export/export.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525

2626
from .recipe import ExportRecipe
2727

28+
from torch._export.pass_base import PassType
29+
from executorch.exir.program._program import _transform
30+
2831

2932
class Stage(ABC):
3033
"""
@@ -95,9 +98,7 @@ class ExportStage(Stage):
9598

9699
def __init__(
97100
self,
98-
pre_edge_transform_passes: Optional[
99-
Callable[[ExportedProgram], ExportedProgram]
100-
] = None,
101+
pre_edge_transform_passes: Optional[List[PassType]] = None,
101102
) -> None:
102103
self._exported_program: Dict[str, ExportedProgram] = {}
103104
self._pre_edge_transform_passes = pre_edge_transform_passes
@@ -153,10 +154,10 @@ def run(
153154
)
154155

155156
# Apply pre-edge transform passes if available
156-
if self._pre_edge_transform_passes is not None:
157-
for pre_edge_transform_pass in self._pre_edge_transform_passes:
158-
self._exported_program[method_name] = pre_edge_transform_pass(
159-
self._exported_program[method_name]
157+
if pre_edge_transform_passes:= self._pre_edge_transform_passes or []:
158+
for pass_ in pre_edge_transform_passes:
159+
self._exported_program[method_name] = _transform(
160+
self._exported_program[method_name], pass_
160161
)
161162

162163
def get_artifacts(self) -> Dict[str, ExportedProgram]:

0 commit comments

Comments
 (0)