Commit 757f704

Add TorchAO wrapper config to allow filter_fn for quantize_

ghstack-source-id: 1c9b691
ghstack-comment-id: 3172341537
Pull-Request: #13264

1 parent 76a4062 commit 757f704

8 files changed: +232 -84 lines changed

backends/xnnpack/recipes/xnnpack_recipe_provider.py

Lines changed: 27 additions & 13 deletions

@@ -25,6 +25,7 @@
     get_xnnpack_executorch_backend_config,
 )
 from executorch.export import (
+    AOQuantizationConfig,
     BackendRecipeProvider,
     ExportRecipe,
     LoweringRecipe,
@@ -57,31 +58,37 @@ def create_recipe(
         if recipe_type == XNNPackRecipeType.FP32:
             return self._build_fp32_recipe(recipe_type)
 
-        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL:
+        elif recipe_type == XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL:
             return self._build_quantized_recipe(
                 recipe_type, is_per_channel=True, is_dynamic=True
             )
 
-        elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_CHANNEL:
+        elif recipe_type == XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL:
             return self._build_quantized_recipe(
                 recipe_type, is_per_channel=True, is_dynamic=False
             )
 
-        elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_TENSOR:
+        elif recipe_type == XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR:
             return self._build_quantized_recipe(
                 recipe_type, is_per_channel=False, is_dynamic=False
             )
 
-        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL:
-            return self._build_int8da_intx_weight_recipe(
+        elif (
+            recipe_type
+            == XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL
+        ):
+            return self._build_torchao_quantized_recipe(
                 recipe_type=recipe_type,
                 is_per_channel=True,
                 weight_dtype=torch.int4,
             )
 
-        elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
+        elif (
+            recipe_type
+            == XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+        ):
             group_size = kwargs.get("group_size", 32)
-            return self._build_int8da_intx_weight_recipe(
+            return self._build_torchao_quantized_recipe(
                 recipe_type=recipe_type,
                 is_per_channel=False,
                 weight_dtype=torch.int4,
@@ -132,7 +139,7 @@ def _build_quantized_recipe(
             executorch_backend_config=get_xnnpack_executorch_backend_config(),
         )
 
-    def _build_int8da_intx_weight_recipe(
+    def _build_torchao_quantized_recipe(
         self,
         recipe_type: RecipeType,
         is_per_channel: bool = True,
@@ -141,17 +148,21 @@ def _build_int8da_intx_weight_recipe(
     ) -> ExportRecipe:
         if is_per_channel:
             weight_granularity = PerAxis(axis=0)
+            assert weight_dtype == torch.int4 or weight_dtype == torch.int8
         else:
            weight_granularity = PerGroup(group_size=group_size)
+            assert weight_dtype == torch.int4
 
-        config = Int8DynamicActivationIntxWeightConfig(
-            weight_dtype=weight_dtype,
-            weight_granularity=weight_granularity,
+        config = AOQuantizationConfig(
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=weight_dtype,
+                weight_granularity=weight_granularity,
+            )
         )
 
         quant_recipe = QuantizationRecipe(
             quantizers=None,
-            ao_base_config=[config],
+            ao_quantization_configs=[config],
         )
 
         return ExportRecipe(
@@ -162,7 +173,10 @@ def _build_int8da_intx_weight_recipe(
         )
 
     def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None:
-        if recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR:
+        if (
+            recipe_type
+            == XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+        ):
             expected_keys = {"group_size"}
             unexpected = set(kwargs.keys()) - expected_keys
             if unexpected:
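For orientation, here is a minimal sketch of what the renamed _build_torchao_quantized_recipe now assembles. The executorch names come from the diff above; the torchao import paths are assumptions based on torchao's current layout:

import torch
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

from executorch.export import AOQuantizationConfig, QuantizationRecipe

# Per-channel: int4 or int8 weights, quantized along axis 0.
per_channel = AOQuantizationConfig(
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerAxis(axis=0),
    )
)

# Per-tensor: int4 weights quantized in groups (default group_size=32).
per_group = AOQuantizationConfig(
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size=32),
    )
)

# The builder hands the wrapped config to the recipe via the renamed field.
quant_recipe = QuantizationRecipe(quantizers=None, ao_quantization_configs=[per_group])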

backends/xnnpack/recipes/xnnpack_recipe_types.py

Lines changed: 12 additions & 9 deletions

@@ -13,19 +13,22 @@ class XNNPackRecipeType(RecipeType):
     """XNNPACK-specific recipe types"""
 
     FP32 = "fp32"
+
+    ## PT2E-based quantization recipes
     # INT8 Dynamic Quantization
-    INT8_DYNAMIC_PER_CHANNEL = "int8_dynamic_per_channel"
+    PT2E_INT8_DYNAMIC_PER_CHANNEL = "pt2e_int8_dynamic_per_channel"
+    # INT8 Static Quantization, needs calibration dataset
+    PT2E_INT8_STATIC_PER_CHANNEL = "pt2e_int8_static_per_channel"
+    PT2E_INT8_STATIC_PER_TENSOR = "pt2e_int8_static_per_tensor"
+
+    ## TorchAO-based quantization recipes
     # INT8 Dynamic Activations INT4 Weight Quantization, Axis = 0
-    INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8da_int4w_per_channel"
+    TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = (
+        "torchao_int8da_int4w_per_channel"
+    )
     # INT8 Dynamic Activations INT4 Weight Quantization, default group_size = 32
     # can be overriden by group_size kwarg
-    INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8da_int4w_per_tensor"
-    # INT8 Static Activations INT4 Weight Quantization
-    INT8_STATIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8a_int4w_per_channel"
-    INT8_STATIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8a_int44w_per_tensor"
-    # INT8 Static Quantization, needs calibration dataset
-    INT8_STATIC_PER_CHANNEL = "int8_static_per_channel"
-    INT8_STATIC_PER_TENSOR = "int8_static_per_tensor"
+    TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "torchao_int8da_int4w_per_tensor"
 
     @classmethod
     def get_backend_name(cls) -> str:
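Callers select recipes by the prefixed names, which makes the underlying quantization flow (PT2E quantizer vs. TorchAO config) explicit. A usage sketch; the import path for XNNPackRecipeType is an assumption based on the file location above:

from executorch.backends.xnnpack.recipes.xnnpack_recipe_types import XNNPackRecipeType
from executorch.export import ExportRecipe

# group_size is only accepted by the TorchAO per-tensor recipe and
# defaults to 32 when omitted.
recipe = ExportRecipe.get_recipe(
    XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
    group_size=64,
)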

backends/xnnpack/test/recipes/test_xnnpack_recipes.py

Lines changed: 54 additions & 37 deletions

@@ -19,8 +19,10 @@
 from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
 from executorch.exir.schema import DelegateCall, Program
 from executorch.export import export, ExportRecipe, recipe_registry
+from export.types import StageType
 from torch import nn
 from torch.testing._internal.common_quantization import TestHelperModules
+from torchao.quantization.utils import compute_error
 
 
 class TestXnnpackRecipes(unittest.TestCase):
@@ -38,6 +40,28 @@ def check_fully_delegated(self, program: Program) -> None:
         self.assertEqual(len(instructions), 1)
         self.assertIsInstance(instructions[0].instr_args, DelegateCall)
 
+    # pyre-ignore
+    def _compare_eager_quantized_model_outputs(
+        self, session, example_inputs, atol: float
+    ) -> None:
+        """Utility to compare eager quantized model output with session output after xnnpack lowering"""
+        torch_export_stage_output = session.get_stage_artifacts()[
+            StageType.TORCH_EXPORT
+        ]
+        eager_quantized_model = torch_export_stage_output.data["forward"].module()
+        output = session.run_method("forward", example_inputs[0])[0]
+        expected = eager_quantized_model(*example_inputs[0])
+        Tester._assert_outputs_equal(output, expected, atol=atol)
+
+    def _compare_eager_unquantized_model_outputs(
+        self, session, eager_unquantized_model, example_inputs, sqnr_threshold=20
+    ):
+        """Utility to compare eager unquantized model output with session output using SQNR"""
+        quantized_output = session.run_method("forward", example_inputs[0])[0]
+        original_output = eager_unquantized_model(*example_inputs[0])
+        error = compute_error(original_output, quantized_output)
+        self.assertTrue(error > sqnr_threshold)
+
     def test_basic_recipe(self) -> None:
         m_eager = TestHelperModules.TwoLinearModule().eval()
         example_inputs = [(torch.randn(9, 8),)]
@@ -46,18 +70,13 @@ def test_basic_recipe(self) -> None:
             example_inputs=example_inputs,
             export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32),
         )
-        self.assertTrue(
-            torch.allclose(
-                session.run_method("forward", example_inputs[0])[0],
-                m_eager(*example_inputs[0]),
-                atol=1e-3,
-            )
-        )
+        self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3)
         self.check_fully_delegated(session.get_executorch_program())
+        self._compare_eager_unquantized_model_outputs(session, m_eager, example_inputs)
 
     def test_int8_dynamic_quant_recipe(self) -> None:
         test_cases = [
-            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL),
+            ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL),
         ]
 
         for export_recipe in test_cases:
@@ -70,19 +89,18 @@ def test_int8_dynamic_quant_recipe(self) -> None:
                     example_inputs=example_inputs,
                     export_recipe=export_recipe,
                 )
-                self.assertTrue(
-                    torch.allclose(
-                        session.run_method("forward", example_inputs[0])[0],
-                        m_eager(*example_inputs[0]),
-                        atol=1e-1,
-                    )
+                self._compare_eager_quantized_model_outputs(
+                    session, example_inputs, 1e-2
                 )
                 self.check_fully_delegated(session.get_executorch_program())
+                self._compare_eager_unquantized_model_outputs(
+                    session, m_eager, example_inputs
+                )
 
     def test_int8_static_quant_recipe(self) -> None:
         test_cases = [
-            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_CHANNEL),
-            ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_TENSOR),
+            ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL),
+            ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR),
         ]
 
         for export_recipe in test_cases:
@@ -95,14 +113,13 @@ def test_int8_static_quant_recipe(self) -> None:
                     example_inputs=example_inputs,
                     export_recipe=export_recipe,
                 )
-                self.assertTrue(
-                    torch.allclose(
-                        session.run_method("forward", example_inputs[0])[0],
-                        m_eager(*example_inputs[0]),
-                        atol=1e-1,
-                    )
+                self._compare_eager_quantized_model_outputs(
+                    session, example_inputs, 1e-3
                 )
                 self.check_fully_delegated(session.get_executorch_program())
+                self._compare_eager_unquantized_model_outputs(
+                    session, m_eager, example_inputs
+                )
 
     def test_8a4w_recipe(self) -> None:
         class SimpleLinearModel(nn.Module):
@@ -116,10 +133,10 @@ def forward(self, x) -> torch.Tensor:
 
         test_cases = [
             ExportRecipe.get_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL,
             ),
             ExportRecipe.get_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
                 group_size=32,
             ),
         ]
@@ -133,23 +150,22 @@ def forward(self, x) -> torch.Tensor:
                     example_inputs=example_inputs,
                     export_recipe=export_recipe,
                 )
-                self.assertTrue(
-                    torch.allclose(
-                        session.run_method("forward", example_inputs[0])[0],
-                        model(*example_inputs[0]),
-                        atol=1e-2,
-                    )
-                )
                 self.check_fully_delegated(session.get_executorch_program())
+                self._compare_eager_quantized_model_outputs(
+                    session, example_inputs, 1e-3
+                )
+                self._compare_eager_unquantized_model_outputs(
+                    session, model, example_inputs
+                )
 
     def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType:
         # Map QuantType to corresponding recipe name.
         if quant_type == QuantType.STATIC_PER_CHANNEL:
-            return XNNPackRecipeType.INT8_STATIC_PER_CHANNEL
+            return XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL
         elif quant_type == QuantType.DYNAMIC_PER_CHANNEL:
-            return XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL
+            return XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL
         elif quant_type == QuantType.STATIC_PER_TENSOR:
-            return XNNPackRecipeType.INT8_STATIC_PER_TENSOR
+            return XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR
         elif quant_type == QuantType.NONE:
             return XNNPackRecipeType.FP32
         else:
@@ -224,12 +240,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size(
 
         # Should not raise any exception
         recipe_w_default_group = provider.create_recipe(
-            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
+            XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR
        )
        self.assertIsNotNone(recipe_w_default_group)
 
        recipe = provider.create_recipe(
-            XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, group_size=64
+            XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+            group_size=64,
        )
        self.assertIsNotNone(recipe)
 
@@ -240,7 +257,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size(
 
         with self.assertRaises(ValueError) as cm:
             provider.create_recipe(
-                XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
+                XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
                 group_size="32",  # String instead of int
             )
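The new _compare_eager_unquantized_model_outputs helper replaces exact torch.allclose checks with an SQNR gate when comparing against the unquantized eager model. A sketch of the metric, assuming torchao's compute_error implements the usual SQNR-in-decibels definition:

import torch

def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantized-noise ratio: 20 * log10(||x|| / ||x - y||), in dB.
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - quantized)
    return 20 * torch.log10(signal / noise)

# The default threshold of 20 dB accepts outputs whose quantization error is
# at least an order of magnitude smaller than the reference signal.
x = torch.randn(8, 16)
y = x + 0.01 * torch.randn_like(x)
assert sqnr_db(x, y) > 20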

export/__init__.py

Lines changed: 8 additions & 1 deletion

@@ -15,12 +15,19 @@
 """
 
 from .export import export, ExportSession
-from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType
+from .recipe import (
+    AOQuantizationConfig,
+    ExportRecipe,
+    LoweringRecipe,
+    QuantizationRecipe,
+    RecipeType,
+)
 from .recipe_provider import BackendRecipeProvider
 from .recipe_registry import recipe_registry
 from .types import StageType
 
 __all__ = [
+    "AOQuantizationConfig",
     "StageType",
     "ExportRecipe",
     "LoweringRecipe",

export/recipe.py

Lines changed: 20 additions & 3 deletions

@@ -6,7 +6,9 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from enum import Enum, EnumMeta
-from typing import List, Optional, Sequence
+from typing import Callable, List, Optional, Sequence
+
+import torch
 
 from executorch.exir._warnings import experimental
 
@@ -64,6 +66,20 @@ class Mode(str, Enum):
     RELEASE = "release"
 
 
+@dataclass
+class AOQuantizationConfig:
+    """
+    Configuration for torchao quantization with optional filter function.
+
+    Attributes:
+        ao_base_config: The AOBaseConfig for quantization
+        filter_fn: Optional filter function to selectively apply quantization
+    """
+
+    ao_base_config: AOBaseConfig
+    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None
+
+
 @dataclass
 class QuantizationRecipe:
     """
@@ -73,11 +89,12 @@ class QuantizationRecipe:
 
     Attributes:
         quantizers: Optional list of quantizers for model quantization
-        ao_base_config: Optional list of AO base configurations
+        ao_quantization_configs: Optional list of AOQuantizationConfig objects that pair
+            AOBaseConfig with optional filter functions
     """
 
     quantizers: Optional[List[Quantizer]] = None
-    ao_base_config: Optional[List[AOBaseConfig]] = None
+    ao_quantization_configs: Optional[List[AOQuantizationConfig]] = None
 
     def get_quantizers(self) -> Optional[List[Quantizer]]:
         """