
Commit 07879f1

Ninja91 authored and facebook-github-bot committed
Add 16A8W quantization configuration utility for ARM backend (#13175)
Summary: Pull Request resolved: #13175

This diff implements a 16A8W (16-bit activations, 8-bit weights) quantization configuration utility for the ExecuTorch ARM backend, following the feedback from D79746479.

## Key Changes

**1. New Quantization Configuration Function**
- Add `get_16a8w_quantization_config()` in `fbcode/executorch/backends/arm/quantizer/arm_quantizer.py`.
- Provides 16-bit activations with `HistogramObserver` (better precision than 8A8W).
- Keeps weights at 8 bits with `MinMaxObserver`/`PerChannelMinMaxObserver` (memory efficient).
- **Technically supported by TOSA through the [EXT-INT16 extension/profile](https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d).**

**2. Test Implementation**
- Add `test_linear_16a8w_tosa_INT()` in `fbcode/executorch/backends/arm/test/ops/test_linear.py`.
- Demonstrates usage of the new 16A8W quantization configuration.

## Benefits
- **Better precision**: 16-bit activations provide higher precision than 8-bit, which is useful for carrying precision through recurrent neural networks.
- **Configurable**: Supports the same parameters as the existing quantization configurations.

## Testing
The implementation provides the utility function and test infrastructure. Note that the test reveals the TOSA backend has limited INT16 support for some operations (view operations only support INT8/INT32/FP32/BOOL); this is expected and shows the configuration correctly produces INT16 tensors.

Differential Revision: D79763381
1 parent 016eece commit 07879f1
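
For orientation, here is a minimal sketch of how the new configuration would typically be attached to a `TOSAQuantizer`, using only the APIs that appear in this diff (`TOSAQuantizer`, `set_io`, `get_16a8w_quantization_config`); the `TosaSpecification` import path is an assumption, and the snippet is illustrative rather than part of the change.

```python
# Illustrative usage sketch (not part of this commit): wire the new 16A8W
# config into a TOSAQuantizer the same way the test pipeline below does.
from executorch.backends.arm.quantizer.arm_quantizer import (
    TOSAQuantizer,
    get_16a8w_quantization_config,
)

# Import path assumed; TosaSpecification is only referenced, not imported,
# in the hunks shown below.
from executorch.backends.arm.tosa_specification import TosaSpecification

tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
quantizer = TOSAQuantizer(tosa_spec)

# 16-bit activations, 8-bit per-channel weights.
config = get_16a8w_quantization_config(is_per_channel=True)
quantizer.set_io(config)
```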

File tree

3 files changed: +144 −6 lines changed


backends/arm/quantizer/arm_quantizer.py

Lines changed: 105 additions & 1 deletion
The first hunk (`@@ -144,6 +144,111 @@`) adds the new function directly after `get_symmetric_quantization_config()`:

```python
@functools.lru_cache
def get_16a8w_quantization_config(
    is_per_channel: bool = True,
    is_qat: bool = False,
    is_dynamic: bool = False,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
):
    """
    16A8W quantization config: 16-bit activations, 8-bit weights.

    This configuration provides better accuracy than 8A8W while maintaining
    reasonable memory usage through 8-bit weights.

    Args:
        is_per_channel: Whether to use per-channel quantization for weights
        is_qat: Whether this is for Quantization Aware Training
        is_dynamic: Whether to use dynamic quantization
        weight_qmin: Minimum quantization value for weights
        weight_qmax: Maximum quantization value for weights

    Returns:
        QuantizationConfig with 16-bit activations and 8-bit weights
    """
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    # Setup observer/fake-quant for 16-bit activations
    if is_qat:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = FakeQuantize
            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
                averaging_constant=1
            )
            extra_args["observer"] = dynamic_quant_observer
        else:
            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    else:
        if is_dynamic:
            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
        else:
            # HistogramObserver works well for 16-bit range
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    # 16-bit activation quantization spec
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int16).min,  # -32768
        quant_max=torch.iinfo(torch.int16).max,  # 32767
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )

    # Setup quantization config for weights (same as 8A8W - use 8-bit weights)
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    # Determine the right observer/fake-quant constructor
    if is_qat:
        # Set plain fake-quant with true min/max
        weight_observer_or_fake_quant_ctr = FakeQuantize
    else:
        # PTQ: set min/max observer
        weight_observer_or_fake_quant_ctr = (
            PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
        )

    weight_extra_args = {"eps": 2**-12}

    # 8-bit weight quantization spec (keep weights at 8-bit for memory efficiency)
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **weight_extra_args
        ),
    )

    bias_quantization_spec = None
    if is_dynamic:
        quantization_config = QuantizationConfig(
            act_quantization_spec,  # 16-bit input activations
            None,
            weight_quantization_spec,  # 8-bit weights
            bias_quantization_spec,
        )
    else:
        quantization_config = QuantizationConfig(
            act_quantization_spec,  # 16-bit input activations
            act_quantization_spec,  # 16-bit output activations
            weight_quantization_spec,  # 8-bit weights
            bias_quantization_spec,
        )
    return quantization_config
```

The second hunk (`@@ -217,7 +322,6 @@`) only removes a stray blank line between `class TOSAQuantizer(Quantizer):` and its `__init__` signature.
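
As a quick sanity check on what the new function returns: the activation spec uses an `int32` container dtype but clamps values to the `int16` range, which is what the EXT-INT16 path expects. The sketch below illustrates this; the `input_activation` attribute name is an assumption about the `QuantizationConfig` dataclass and does not appear in this diff.

```python
# Hypothetical inspection of the returned config (attribute name assumed).
import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
    get_16a8w_quantization_config,
)

cfg = get_16a8w_quantization_config()
act_spec = cfg.input_activation  # assumed field name on QuantizationConfig

assert act_spec.dtype == torch.int32       # int32 container dtype
assert act_spec.quant_min == -(2**15)      # -32768, int16 lower bound
assert act_spec.quant_max == 2**15 - 1     # 32767, int16 upper bound
```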

backends/arm/test/ops/test_linear.py

Lines changed: 33 additions & 0 deletions
The first hunk (`@@ -11,6 +11,9 @@`) adds one import next to the existing ones:

```python
from executorch.backends.arm.quantizer.arm_quantizer import (
    get_16a8w_quantization_config,
)
```

The second hunk (`@@ -258,3 +261,33 @@`) appends the new test after `test_linear_vgf_INT()`:

```python
@pytest.mark.xfail(
    reason="TOSA backend has limited INT16 support - view operations only support INT8/INT32/FP32/BOOL"
)
@common.parametrize("test_data", test_data_rank1_INT)
def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
    test_data, out_features, has_bias, per_channel_quantization = test_data()
    in_features = test_data.shape[-1]

    # Create pipeline with custom 16A8W quantization config
    pipeline = TosaPipelineINT[input_t1](
        Linear(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
        ),
        (test_data,),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        quantization_config=get_16a8w_quantization_config(
            is_per_channel=per_channel_quantization
        ),
    )

    # Run the pipeline
    pipeline.run()
```

backends/arm/test/tester/test_pipeline.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -107,7 +107,6 @@ def __init__(
             Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
         ] = None,
     ):
-
         self.tester = ArmTester(
             module,
             example_inputs=test_data,
@@ -306,6 +305,7 @@ def __init__(
         rtol: float = 1e-03,
         qtol: int = 1,
         dynamic_shapes: Optional[Tuple[Any]] = None,
+        quantization_config: Optional[Any] = None,
     ):
         tosa_profiles = {
             "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"),
@@ -317,9 +317,11 @@ def __init__(
         )
 
         quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-        quantization_config = get_symmetric_quantization_config(
-            is_per_channel=per_channel_quantization
-        )
+        # Use custom quantization config if provided, otherwise use default
+        if quantization_config is None:
+            quantization_config = get_symmetric_quantization_config(
+                is_per_channel=per_channel_quantization
+            )
         if symmetric_io_quantization:
             quantizer.set_io(quantization_config)
         quant_stage = Quantize(quantizer, quantization_config)
@@ -856,7 +858,6 @@ def __init__(
             Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
         ] = None,
     ):
-
         tosa_profile = TosaSpecification.create_from_string(tosa_version)
         compile_spec = common.get_vgf_compile_spec(
             tosa_profile, compiler_flags=vgf_compiler_flags, custom_path=custom_path
```
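
The design intent of the `TosaPipelineINT` change is backward compatibility: when no `quantization_config` is supplied, the pipeline falls back to the existing symmetric 8A8W default, so current callers behave exactly as before. A standalone sketch of that fallback (the helper name is hypothetical and the arguments are trimmed to the two that matter):

```python
from typing import Any, Optional

from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_quantization_config,
)


def resolve_quantization_config(
    quantization_config: Optional[Any] = None,
    per_channel_quantization: bool = True,
) -> Any:
    """Mirror of the fallback added to TosaPipelineINT (hypothetical helper)."""
    if quantization_config is None:
        # No caller-supplied config: keep the existing 8A8W symmetric default.
        quantization_config = get_symmetric_quantization_config(
            is_per_channel=per_channel_quantization
        )
    return quantization_config
```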
