Commit 353da24

Update on "Arm backend: Add 16A8W support and test for add operation"
Add 16A8W quantization support and comprehensive tests for the add operation in ExecutorTorch ARM backend targeting Ethos U55 and U85 NPUs. This follows the pattern established for linear operations, extending int16 support to add operations with hardware-specific testing. Changes: - Add INT16 dtype validation support in op_add.py - Add test_add_tensor_16a8w_tosa_INT test function with U55/U85 pipeline support - Add U55 and U85 specific 16A8W tests with proper xfail decorators - Fix U55/U85 test parameter usage (remove unsupported tosa_extensions, clean quantizer function calls) - Update xfail reasons to consistent 'Vela compilation fails with Invalid arguments' pattern Differential Revision: [D80510463](https://our.internmc.facebook.com/intern/diff/D80510463) cc digantdesai freddan80 per zingo oscarandersson8218 [ghstack-poisoned]
2 parents acd368f + 445c3f5 commit 353da24
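Before the individual file diffs, a minimal sketch of how the 16A8W config is wired into a TOSA quantizer, mirroring the `get_symmetric_a16w8_linear_quantizer` helper added to `test_linear.py` in this commit. The `SimpleAdd` module and the hard-coded "1.0" profile string are illustrative assumptions, not part of the change:

```python
# Minimal sketch (not from this commit): wiring the 16A8W config into a TOSA quantizer,
# following the get_symmetric_a16w8_linear_quantizer helper added in test_linear.py below.
# The SimpleAdd module and the hard-coded profile string are illustrative assumptions.
import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize


class SimpleAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


# 16-bit activations, 8-bit weights (per-tensor here for simplicity).
a16w8_config = get_symmetric_a16w8_quantization_config(is_per_channel=False)

# Build a quantizer against an int16-capable TOSA profile and apply the config globally.
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT+int16"))
quantizer.set_global(a16w8_config)

# The test pipelines consume this as their "quantize" stage.
quantize_stage = Quantize(quantizer, a16w8_config)
```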

47 files changed (+541, −1772 lines). Only a subset of the changed files is shown below; large commits hide some content by default.

backends/arm/operators/op_view.py

Lines changed: 7 additions & 1 deletion

@@ -44,7 +44,13 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
+            [
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+                ts.DType.FP32,
+                ts.DType.BOOL,
+            ],
             output.tosa_spec,
         )

backends/arm/quantizer/arm_quantizer.py

Lines changed: 0 additions & 80 deletions

@@ -225,86 +225,6 @@ def get_symmetric_a16w8_quantization_config(
     return quantization_config


-@functools.lru_cache
-def get_symmetric_a16w8_quantization_config(
-    is_per_channel: bool = True,
-    is_qat: bool = False,
-    is_dynamic: bool = False,
-    weight_qmin: int = -127,
-    weight_qmax: int = 127,
-):
-    """
-    16A8W quantization config: 16-bit activations, 8-bit weights.
-
-    This configuration provides better accuracy than 8A8W while maintaining
-    reasonable memory usage through 8-bit weights.
-
-    Args:
-        is_per_channel: Whether to use per-channel quantization for weights
-        is_qat: Whether this is for Quantization Aware Training
-        is_dynamic: Whether to use dynamic quantization
-        weight_qmin: Minimum quantization value for weights
-        weight_qmax: Maximum quantization value for weights
-
-    Returns:
-        QuantizationConfig with 16-bit activations and 8-bit weights
-    """
-    extra_args: Dict[str, Any] = {"eps": 2**-12}
-
-    # Setup observer/fake-quant for 16-bit activations
-    if is_qat:
-        if is_dynamic:
-            act_observer_or_fake_quant_ctr = FakeQuantize
-            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
-                averaging_constant=1
-            )
-            extra_args["observer"] = dynamic_quant_observer
-        else:
-            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
-    else:
-        if is_dynamic:
-            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
-        else:
-            # HistogramObserver works well for 16-bit range
-            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
-
-    # 16-bit activation quantization spec
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int16,
-        quant_min=torch.iinfo(torch.int16).min,  # -32768
-        quant_max=torch.iinfo(torch.int16).max,  # 32767
-        qscheme=torch.per_tensor_symmetric,
-        is_dynamic=is_dynamic,
-        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
-            **extra_args,
-        ),
-    )
-
-    # Instead of reconstructing quantization_config, just clone and update as needed
-    # Clone the quantization_config from get_symmetric_quantization_config and update activation spec
-    base_config = get_symmetric_quantization_config(
-        is_per_channel=is_per_channel,
-        is_qat=is_qat,
-        is_dynamic=is_dynamic,
-    )
-    # Replace activation quantization spec with 16-bit version
-    if is_dynamic:
-        quantization_config = QuantizationConfig(
-            act_quantization_spec,  # 16-bit input activations
-            None,
-            base_config.weight,  # 8-bit weights from base config
-            None,
-        )
-    else:
-        quantization_config = QuantizationConfig(
-            act_quantization_spec,  # 16-bit input activations
-            act_quantization_spec,  # 16-bit output activations
-            base_config.weight,  # 8-bit weights from base config
-            None,
-        )
-    return quantization_config
-
-
 NodeFilterType = Callable[[Node], bool]
 """Type for a Node Filter used by annotators. A Node filter is a function that takes
 a Node and returns whether the node should be annotated or not.

backends/arm/scripts/utils.sh

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ function check_os_support() {
     # Linux on arm64/aarch64
     # Darwin on arm64/aarch64
     if [[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]; then
-        if [[ "${OS}" != "Darwin" ]] || [[ "${OS}" != "Linux" ]]; then
+        if [[ "${OS}" != "Darwin" ]] && [[ "${OS}" != "Linux" ]]; then
            echo "Error: Only Linux and Darwin are supported on arm64"
            exit 1
        fi
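Note on this fix: the original `||` condition is true for every value of `${OS}` (any string differs from at least one of "Darwin" and "Linux"), so the error branch was always taken on arm64 hosts; with `&&`, the script errors out only when the OS is neither Darwin nor Linux, which is the intended check.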

backends/arm/test/ops/test_add.py

Lines changed: 3 additions & 3 deletions

@@ -261,7 +261,7 @@ def get_symmetric_a16w8_add_quantizer(per_channel_quantization=False):

 @common.parametrize("test_data", Add.test_data)
 @pytest.mark.xfail(
-    reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank"
+    reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13730"
 )
 def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
     """Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)"""

@@ -289,7 +289,7 @@ def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
 @common.parametrize("test_data", Add.test_data)
 @common.XfailIfNoCorstone300
 @pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations"
+    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
 )
 def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
     """Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""

@@ -317,7 +317,7 @@ def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
 @common.parametrize("test_data", Add.test_data)
 @common.XfailIfNoCorstone320
 @pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations"
+    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
 )
 def test_add_tensor_16a8w_u85_INT16(test_data: input_t1):
     """Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""

backends/arm/test/ops/test_linear.py

Lines changed: 68 additions & 3 deletions

@@ -9,9 +9,12 @@
 from typing import Tuple

 import pytest
-
 import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest

 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,

@@ -20,6 +23,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize

 aten_op = "torch.ops.aten.linear.default"

@@ -143,7 +148,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
     pipeline.run()


-@pytest.mark.flaky(reruns=5)  # TODO: Investigate flakyness.
 @common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
 def test_linear_tosa_INT(test_data: torch.Tensor):
     test_data, out_features, has_bias, per_channel_quantization = test_data()

@@ -243,3 +247,64 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
         per_channel_quantization=per_channel_quantization,
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_linear_quantizer(
+    u55_config=False, per_channel_quantization=False
+):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+    quantizer.set_module_type(
+        torch.nn.Linear,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+@pytest.mark.xfail(
+    reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph"
+)
+def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    test_data, out_features, has_bias, per_channel_quantization = test_data()
+    in_features = test_data.shape[-1]
+
+    # Create pipeline with custom 16A8W quantization config
+    pipeline = TosaPipelineINT[input_t1](
+        Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=has_bias,
+        ),
+        (test_data,),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_linear_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    # Run the pipeline
+    pipeline.run()
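As a usage sketch (assuming the standard ExecuTorch Arm test environment with the TOSA reference model available), the new test can be selected directly with something like `pytest backends/arm/test/ops/test_linear.py -k test_linear_16a8w_tosa_INT`; it is currently marked xfail pending int16 linear support in the reference model.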

backends/cadence/aot/ref_implementations.py

Lines changed: 50 additions & 0 deletions

@@ -6,6 +6,8 @@

 # pyre-strict

+from typing import Optional
+
 import torch
 from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library

@@ -177,6 +179,54 @@ def quantized_add(
     )


+@impl(m, "quantized_linear")
+def quantized_linear(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    out_zero_point: int,
+    offset: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    Quantized linear (transposed matmul) operation.
+
+    Args:
+    - src (Tensor): The activations tensor
+    - weight (Tensor): The weight tensor
+    - bias (Tensor): The bias tensor
+    - in_zero_point (int): The quantized mapping of zero for the input
+    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
+    - out_multiplier (Tensor): The multiplier used to scale the output
+    - out_shift (Tensor): The shift used to scale the output
+    - out_zero_point (int): The quantized mapping of zero for the output
+    - offset (Tensor): Unused
+    """
+    out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+
+    N, K = weight.shape
+
+    leading_dims = src.shape[:-1]
+    src = src.view(-1, K)
+
+    dtype = src.dtype
+    supported_dtypes = [torch.int8, torch.uint8, torch.int32]
+    if dtype not in supported_dtypes:
+        raise ValueError(
+            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}"
+        )
+
+    out = torch.nn.functional.linear(
+        src - in_zero_point, weight - weight_zero_point, bias
+    )
+    return quantize_per_tensor(
+        out, out_scale, out_zero_point, -128, 127, dtype
+    ).reshape(*leading_dims, N)
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
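To make the reference semantics above concrete, here is a small worked sketch of the same math: accumulate on zero-point-shifted operands, then requantize to the output range. The float `out_scale` stand-in is an assumption; the kernel above derives its scale from `out_multiplier` and `out_shift` using its own fixed-point convention.

```python
# Worked sketch of the quantized-linear math used by the reference kernel above
# (not a call into the Cadence op). out_scale is a hand-picked stand-in for the
# value the kernel derives from out_multiplier / out_shift.
import torch

src = torch.tensor([[10, -3]], dtype=torch.int64)           # quantized activations, shape (1, K)
weight = torch.tensor([[2, 1], [0, 5]], dtype=torch.int64)  # quantized weights, shape (N, K)
bias = torch.tensor([4, -2], dtype=torch.int64)
in_zero_point, weight_zero_point, out_zero_point = 1, 0, 0

# Integer accumulation: (src - in_zp) @ (weight - w_zp).T + bias
acc = (src - in_zero_point) @ (weight - weight_zero_point).T + bias
# acc == [[18, -22]]

# Requantize to int8: scale, shift by the output zero point, clamp to [-128, 127].
out_scale = 0.05
out = torch.clamp(torch.round(acc * out_scale) + out_zero_point, -128, 127).to(torch.int8)
# out == [[1, -1]]
```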
