Skip to content

Commit 3038914

Browse files
committed
Add 16A8W support and test for add operation
Pull Request resolved: #13789

Add 16A8W quantization support and comprehensive tests for the add operation in ExecutorTorch ARM backend targeting Ethos U55 and U85 NPUs. This follows the pattern established for linear operations, extending int16 support to add operations with hardware-specific testing.

Changes:
- Add INT16 dtype validation support in op_add.py
- Add test_add_tensor_16a8w_tosa_INT test function with U55/U85 pipeline support
- Add U55 and U85 specific 16A8W tests with proper xfail decorators
- Fix U55/U85 test parameter usage (remove unsupported tosa_extensions, clean quantizer function calls)
- Update xfail reasons to consistent 'Vela compilation fails with Invalid arguments' pattern

ghstack-source-id: 308053642
ghstack-source-id: 308053642
@exported-using-ghexport
@bypass-github-pytorch-ci-checks
@bypass-github-pytorch-ci-checks
@bypass-github-executorch-ci-checks

Differential Revision: [D80510463](https://our.internmc.facebook.com/intern/diff/D80510463/)
1 parent 1a7441f commit 3038914

File tree

3 files changed

+117
-4
lines changed

3 files changed

+117
-4
lines changed

backends/arm/operators/op_add.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,16 @@ def define_node(
4747

4848
validate_num_inputs(self.target, inputs, 2)
4949
validate_same_dtype(self.target, [*inputs, output], ts)
50+
valid_dtypes = []
51+
if self.tosa_spec.support_integer():
52+
valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
53+
if self.tosa_spec.support_float():
54+
valid_dtypes.extend([ts.DType.INT32])
55+
5056
validate_valid_dtype(
5157
self.target,
5258
[*inputs, output],
53-
[ts.DType.INT8, ts.DType.INT32],
59+
valid_dtypes,
5460
output.tosa_spec,
5561
)
5662
scale_back = 1.0
@@ -59,15 +65,15 @@ def define_node(
5965
tosa_graph, inputs, node, self.tosa_spec
6066
)
6167
else:
62-
# input[0].dtype == ts.DType.INT32
68+
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
6369
# Non quantized input, natively support by TOSA.ADD
6470
rescaled_inputs = inputs
6571

6672
if output.dtype == ts.DType.INT8:
6773
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
6874
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
6975
else:
70-
# output.dtype == ts.DType.INT32
76+
# output.dtype == ts.DType.INT16 or ts.DType.INT32
7177
add_output = output
7278

7379
input1, input2 = rescaled_inputs
@@ -117,7 +123,7 @@ def define_node(
117123
validate_num_inputs(self.target, inputs, 2)
118124
validate_same_dtype(self.target, [*inputs, output], ts)
119125

120-
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
126+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]:
121127
# Call the inherited define_node for handling integers
122128
super().define_node(node, tosa_graph, inputs, output)
123129
else:

backends/arm/test/ops/test_add.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
import pytest
1111
import torch
1212
from executorch.backends.arm.quantizer import arm_quantizer
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
1317
from executorch.backends.arm.test import common, conftest
1418
from executorch.backends.arm.test.tester.test_pipeline import (
1519
EthosU55PipelineINT,
@@ -235,3 +239,105 @@ def test_add_tensor_vgf_INT(test_data: input_t1):
235239
pipeline.run()
236240
except FileNotFoundError as e:
237241
pytest.skip(f"VKML executor_runner not found - not built - skip {e}")
242+
243+
244+
def get_symmetric_a16w8_add_quantizer(per_channel_quantization=False):
245+
tosa_version = conftest.get_option("tosa_version")
246+
tosa_profiles = {
247+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
248+
}
249+
250+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
251+
quantizer.set_global(
252+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
253+
)
254+
255+
return Quantize(
256+
quantizer,
257+
get_symmetric_a16w8_quantization_config(
258+
is_per_channel=per_channel_quantization
259+
),
260+
)
261+
262+
263+
@common.parametrize("test_data", Add.test_data)
264+
@pytest.mark.xfail(
265+
reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13730"
266+
)
267+
def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
268+
"""Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
269+
per_channel_quantization = False
270+
271+
pipeline = TosaPipelineINT[input_t1](
272+
Add(),
273+
test_data(),
274+
aten_op,
275+
exir_op=[],
276+
per_channel_quantization=per_channel_quantization,
277+
use_to_edge_transform_and_lower=True,
278+
tosa_extensions=["int16"],
279+
)
280+
281+
pipeline.change_args(
282+
"quantize",
283+
get_symmetric_a16w8_add_quantizer(
284+
per_channel_quantization=per_channel_quantization
285+
),
286+
)
287+
pipeline.run()
288+
289+
290+
@common.parametrize("test_data", Add.test_data)
291+
@common.XfailIfNoCorstone300
292+
@pytest.mark.xfail(
293+
reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
294+
)
295+
def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
296+
"""Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
297+
per_channel_quantization = False
298+
299+
pipeline = EthosU55PipelineINT[input_t1](
300+
Add(),
301+
test_data(),
302+
aten_op,
303+
exir_op,
304+
per_channel_quantization=per_channel_quantization,
305+
use_to_edge_transform_and_lower=True,
306+
run_on_fvp=True,
307+
)
308+
309+
pipeline.change_args(
310+
"quantize",
311+
get_symmetric_a16w8_add_quantizer(
312+
per_channel_quantization=per_channel_quantization
313+
),
314+
)
315+
pipeline.run()
316+
317+
318+
@common.parametrize("test_data", Add.test_data)
319+
@common.XfailIfNoCorstone320
320+
@pytest.mark.xfail(
321+
reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
322+
)
323+
def test_add_tensor_16a8w_u85_INT16(test_data: input_t1):
324+
"""Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
325+
per_channel_quantization = False
326+
327+
pipeline = EthosU85PipelineINT[input_t1](
328+
Add(),
329+
test_data(),
330+
aten_op,
331+
exir_op,
332+
per_channel_quantization=per_channel_quantization,
333+
use_to_edge_transform_and_lower=True,
334+
run_on_fvp=True,
335+
)
336+
337+
pipeline.change_args(
338+
"quantize",
339+
get_symmetric_a16w8_add_quantizer(
340+
per_channel_quantization=per_channel_quantization
341+
),
342+
)
343+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def define_arm_tests():
1313

1414
# Operators
1515
test_files += [
16+
"ops/test_add.py",
1617
"ops/test_avg_pool2d.py",
1718
"ops/test_linear.py",
1819
"ops/test_slice.py",

0 commit comments

Comments (0)