Skip to content

Commit add30f6

Browse files
pytorchbot and Ninja91 authored
Arm backend: Add 16A8W support and test for tanh operation (#14214)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13797 by @Ninja91 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/10/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/10/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/9/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/10/orig @diff-train-skip-merge --------- Co-authored-by: Nitin Jain <[email protected]>
1 parent 6b9c0a6 commit add30f6

File tree

1 file changed

+110
-1
lines changed

1 file changed

+110
-1
lines changed

backends/arm/test/ops/test_tanh.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,23 @@
66

77
from typing import Tuple
88

9+
import pytest
910
import torch
11+
from executorch.backends.arm.quantizer.arm_quantizer import (
12+
get_symmetric_a16w8_quantization_config,
13+
TOSAQuantizer,
14+
)
1015

11-
from executorch.backends.arm.test import common
16+
from executorch.backends.arm.test import common, conftest
1217
from executorch.backends.arm.test.tester.test_pipeline import (
1318
EthosU55PipelineINT,
1419
EthosU85PipelineINT,
1520
TosaPipelineFP,
1621
TosaPipelineINT,
1722
VgfPipeline,
1823
)
24+
from executorch.backends.arm.tosa.specification import TosaSpecification
25+
from executorch.backends.xnnpack.test.tester import Quantize
1926

2027
aten_op = "torch.ops.aten.tanh.default"
2128
input_t1 = Tuple[torch.Tensor] # Input x
@@ -105,3 +112,105 @@ def test_tanh_vgf_INT(test_data: Tuple):
105112
tosa_version="TOSA-1.0+INT",
106113
)
107114
pipeline.run()
115+
116+
117+
def get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=False):
    """Build a ``Quantize`` test stage for 16-bit activations / 8-bit weights.

    Args:
        per_channel_quantization: forwarded to the a16w8 config builder;
            when True, weights are quantized per channel.

    Returns:
        A ``Quantize`` pipeline stage wrapping a ``TOSAQuantizer`` whose
        global config is the symmetric a16w8 configuration.

    Raises:
        KeyError: if the ``tosa_version`` conftest option is not a key of
            ``tosa_profiles`` (only "1.0" is supported here).
    """
    tosa_version = conftest.get_option("tosa_version")
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    # Build the config once and reuse it for both the quantizer's global
    # setting and the Quantize stage, so the two can never diverge.
    quantization_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(quantization_config)

    return Quantize(quantizer, quantization_config)
134+
135+
136+
@common.parametrize("test_data", test_data_suite)
@pytest.mark.xfail(
    reason="missing int16 tanh ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13975"
)
def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
    """Run tanh through the TOSA INT pipeline with 16A8W quantization
    (16-bit activations, 8-bit weights) and the int16 extension enabled."""
    per_channel_quantization = False

    # Standard INT pipeline, but with the int16 TOSA extension switched on.
    pipeline = TosaPipelineINT[input_t1](
        Tanh(),
        (test_data(),),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    # Replace the default quantize stage with the a16w8 variant.
    quantize_stage = get_symmetric_a16w8_tanh_quantizer(
        per_channel_quantization=per_channel_quantization
    )
    pipeline.change_args("quantize", quantize_stage)
    pipeline.run()
161+
162+
163+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
)
def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor):
    """Run tanh through the Ethos-U55 INT pipeline on FVP with 16A8W
    quantization (16-bit activations, 8-bit weights)."""
    per_channel_quantization = False

    # U55 pipeline targeting the Corstone-300 FVP.
    pipeline = EthosU55PipelineINT[input_t1](
        Tanh(),
        (test_data(),),
        aten_op,
        exir_ops=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Replace the default quantize stage with the a16w8 variant.
    quantize_stage = get_symmetric_a16w8_tanh_quantizer(
        per_channel_quantization=per_channel_quantization
    )
    pipeline.change_args("quantize", quantize_stage)
    pipeline.run()
189+
190+
191+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
)
def test_tanh_16a8w_u85_INT16(test_data: torch.Tensor):
    """Run tanh through the Ethos-U85 INT pipeline on FVP with 16A8W
    quantization (16-bit activations, 8-bit weights)."""
    per_channel_quantization = False

    # U85 pipeline targeting the Corstone-320 FVP.
    pipeline = EthosU85PipelineINT[input_t1](
        Tanh(),
        (test_data(),),
        aten_op,
        exir_ops=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Replace the default quantize stage with the a16w8 variant.
    quantize_stage = get_symmetric_a16w8_tanh_quantizer(
        per_channel_quantization=per_channel_quantization
    )
    pipeline.change_args("quantize", quantize_stage)
    pipeline.run()

0 commit comments

Comments
 (0)