Skip to content

Commit 1a048c5

Browse files
committed
Add 16A8W support and test for tanh operation
Add 16A8W quantization support and test for the tanh operation in ExecutorTorch ARM backend. This follows the pattern established for linear, mul, and sigmoid operations, extending int16 support to tanh operations. Changes: - Add INT16 dtype validation support in op_tanh.py - Add test_tanh_tensor_16a8w_tosa_INT test function - Enable test_tanh.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. Differential Revision: [D80510815](https://our.internmc.facebook.com/intern/diff/D80510815/) [ghstack-poisoned]
1 parent a5d989a commit 1a048c5

File tree

1 file changed

+107
-1
lines changed

1 file changed

+107
-1
lines changed

backends/arm/test/ops/test_tanh.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,23 @@
66

77
from typing import Tuple
88

9+
import pytest
910
import torch
11+
from executorch.backends.arm.quantizer.arm_quantizer import (
12+
get_symmetric_a16w8_quantization_config,
13+
TOSAQuantizer,
14+
)
1015

11-
from executorch.backends.arm.test import common
16+
from executorch.backends.arm.test import common, conftest
1217
from executorch.backends.arm.test.tester.test_pipeline import (
1318
EthosU55PipelineINT,
1419
EthosU85PipelineINT,
1520
TosaPipelineFP,
1621
TosaPipelineINT,
1722
VgfPipeline,
1823
)
24+
from executorch.backends.arm.tosa_specification import TosaSpecification
25+
from executorch.backends.xnnpack.test.tester import Quantize
1926

2027
aten_op = "torch.ops.aten.tanh.default"
2128
input_t1 = Tuple[torch.Tensor] # Input x
@@ -105,3 +112,102 @@ def test_tanh_vgf_INT(test_data: Tuple):
105112
tosa_version="TOSA-1.0+INT",
106113
)
107114
pipeline.run()
115+
116+
117+
def get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=False):
118+
tosa_version = conftest.get_option("tosa_version")
119+
tosa_profiles = {
120+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
121+
}
122+
123+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
124+
quantizer.set_global(
125+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
126+
)
127+
128+
return Quantize(
129+
quantizer,
130+
get_symmetric_a16w8_quantization_config(
131+
is_per_channel=per_channel_quantization
132+
),
133+
)
134+
135+
136+
@common.parametrize("test_data", test_data_suite)
137+
def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
138+
"""Test tanh operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
139+
per_channel_quantization = False
140+
141+
pipeline = TosaPipelineINT[input_t1](
142+
Tanh(),
143+
(test_data(),),
144+
aten_op,
145+
exir_op=[],
146+
per_channel_quantization=per_channel_quantization,
147+
use_to_edge_transform_and_lower=True,
148+
tosa_extensions=["int16"],
149+
)
150+
151+
pipeline.change_args(
152+
"quantize",
153+
get_symmetric_a16w8_tanh_quantizer(
154+
per_channel_quantization=per_channel_quantization
155+
),
156+
)
157+
pipeline.run()
158+
159+
160+
@common.parametrize("test_data", test_data_suite)
161+
@common.XfailIfNoCorstone300
162+
@pytest.mark.xfail(
163+
reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
164+
)
165+
def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor):
166+
"""Test tanh operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
167+
per_channel_quantization = False
168+
169+
pipeline = EthosU55PipelineINT[input_t1](
170+
Tanh(),
171+
(test_data(),),
172+
aten_op,
173+
exir_ops=[],
174+
per_channel_quantization=per_channel_quantization,
175+
use_to_edge_transform_and_lower=True,
176+
run_on_fvp=True,
177+
)
178+
179+
pipeline.change_args(
180+
"quantize",
181+
get_symmetric_a16w8_tanh_quantizer(
182+
per_channel_quantization=per_channel_quantization
183+
),
184+
)
185+
pipeline.run()
186+
187+
188+
@common.parametrize("test_data", test_data_suite)
189+
@common.XfailIfNoCorstone320
190+
@pytest.mark.xfail(
191+
reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
192+
)
193+
def test_tanh_16a8w_u85_INT16(test_data: torch.Tensor):
194+
"""Test tanh operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
195+
per_channel_quantization = False
196+
197+
pipeline = EthosU85PipelineINT[input_t1](
198+
Tanh(),
199+
(test_data(),),
200+
aten_op,
201+
exir_ops=[],
202+
per_channel_quantization=per_channel_quantization,
203+
use_to_edge_transform_and_lower=True,
204+
run_on_fvp=True,
205+
)
206+
207+
pipeline.change_args(
208+
"quantize",
209+
get_symmetric_a16w8_tanh_quantizer(
210+
per_channel_quantization=per_channel_quantization
211+
),
212+
)
213+
pipeline.run()

0 commit comments

Comments
 (0)