Skip to content

Commit e19d555

Browse files
pytorchbot and Ninja91 authored
Arm backend: Add 16A8W support and test for add operation (#14039)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13789 by @Ninja91 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/5/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/5/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/5/orig @diff-train-skip-merge Co-authored-by: Nitin Jain <[email protected]>
1 parent c1ea7e9 commit e19d555

File tree

3 files changed

+117
-4
lines changed

3 files changed

+117
-4
lines changed

backends/arm/operators/op_add.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,16 @@ def define_node(
4747

4848
validate_num_inputs(self.target, inputs, 2)
4949
validate_same_dtype(self.target, [*inputs, output], ts)
50+
valid_dtypes = []
51+
if self.tosa_spec.support_integer():
52+
valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
53+
if self.tosa_spec.support_float():
54+
valid_dtypes.extend([ts.DType.INT32])
55+
5056
validate_valid_dtype(
5157
self.target,
5258
[*inputs, output],
53-
[ts.DType.INT8, ts.DType.INT32],
59+
valid_dtypes,
5460
output.tosa_spec,
5561
)
5662
scale_back = 1.0
@@ -59,15 +65,15 @@ def define_node(
5965
tosa_graph, inputs, node, self.tosa_spec
6066
)
6167
else:
62-
# input[0].dtype == ts.DType.INT32
68+
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
6369
# Non quantized input, natively support by TOSA.ADD
6470
rescaled_inputs = inputs
6571

6672
if output.dtype == ts.DType.INT8:
6773
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
6874
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
6975
else:
70-
# output.dtype == ts.DType.INT32
76+
# output.dtype == ts.DType.INT16 or ts.DType.INT32
7177
add_output = output
7278

7379
input1, input2 = rescaled_inputs
@@ -117,7 +123,7 @@ def define_node(
117123
validate_num_inputs(self.target, inputs, 2)
118124
validate_same_dtype(self.target, [*inputs, output], ts)
119125

120-
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
126+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]:
121127
# Call the inherited define_node for handling integers
122128
super().define_node(node, tosa_graph, inputs, output)
123129
else:

backends/arm/test/ops/test_add.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
import pytest
1111
import torch
1212
from executorch.backends.arm.quantizer import arm_quantizer
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
1317
from executorch.backends.arm.test import common, conftest
1418
from executorch.backends.arm.test.tester.test_pipeline import (
1519
EthosU55PipelineINT,
@@ -235,3 +239,105 @@ def test_add_tensor_vgf_INT(test_data: input_t1):
235239
pipeline.run()
236240
except FileNotFoundError as e:
237241
pytest.skip(f"VKML executor_runner not found - not built - skip {e}")
242+
243+
244+
def get_symmetric_a16w8_add_quantizer(per_channel_quantization=False):
    """Build a Quantize stage for 16-bit activations / 8-bit weights (16A8W).

    Args:
        per_channel_quantization: forwarded to the a16w8 config factory;
            when True, weights are quantized per channel.

    Returns:
        A ``Quantize`` pipeline stage wrapping a ``TOSAQuantizer`` whose
        global quantization config is the symmetric a16w8 scheme.
    """
    tosa_version = conftest.get_option("tosa_version")
    # Only TOSA 1.0 with the INT profile + int16 extension supports 16-bit
    # activations here; any other version raises KeyError, which is the
    # desired failure mode for unsupported versions.
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    # Build the config once and share it between the quantizer's global
    # setting and the Quantize stage. The original built two identical
    # configs from the same factory call; presumably the config object is
    # not mutated downstream, so a single shared instance is equivalent —
    # NOTE(review): confirm the config is treated as read-only.
    quantization_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(quantization_config)

    return Quantize(quantizer, quantization_config)
261+
262+
263+
@common.parametrize("test_data", Add.test_data)
@pytest.mark.xfail(
    reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13730"
)
def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
    """Run Add through the TOSA INT pipeline with the a16w8 scheme
    (16-bit activations, 8-bit weights) and the int16 extension enabled."""
    use_per_channel = False

    int16_pipeline = TosaPipelineINT[input_t1](
        Add(),
        test_data(),
        aten_op,
        exir_op=[],
        per_channel_quantization=use_per_channel,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    # Swap the pipeline's default quantize stage for the a16w8 one
    # before executing.
    int16_pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_add_quantizer(per_channel_quantization=use_per_channel),
    )
    int16_pipeline.run()
288+
289+
290+
@common.parametrize("test_data", Add.test_data)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
)
def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
    """Run Add on the Ethos-U55 FVP with the a16w8 scheme
    (16-bit activations, 8-bit weights)."""
    use_per_channel = False

    u55_pipeline = EthosU55PipelineINT[input_t1](
        Add(),
        test_data(),
        aten_op,
        exir_op,
        per_channel_quantization=use_per_channel,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Swap the pipeline's default quantize stage for the a16w8 one
    # before executing.
    u55_pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_add_quantizer(per_channel_quantization=use_per_channel),
    )
    u55_pipeline.run()
316+
317+
318+
@common.parametrize("test_data", Add.test_data)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
)
def test_add_tensor_16a8w_u85_INT16(test_data: input_t1):
    """Run Add on the Ethos-U85 FVP with the a16w8 scheme
    (16-bit activations, 8-bit weights)."""
    use_per_channel = False

    u85_pipeline = EthosU85PipelineINT[input_t1](
        Add(),
        test_data(),
        aten_op,
        exir_op,
        per_channel_quantization=use_per_channel,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Swap the pipeline's default quantize stage for the a16w8 one
    # before executing.
    u85_pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_add_quantizer(per_channel_quantization=use_per_channel),
    )
    u85_pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def define_arm_tests():
1313

1414
# Operators
1515
test_files += [
16+
"ops/test_add.py",
1617
"ops/test_avg_pool2d.py",
1718
"ops/test_linear.py",
1819
"ops/test_slice.py",

0 commit comments

Comments (0)