Skip to content

Commit 4c0aefa

Browse files
Ninja91 and facebook-github-bot
authored and committed
Add 16A8W support and test for tanh operation
Summary: Add 16A8W quantization support and test for the tanh operation in ExecutorTorch ARM backend. This follows the pattern established for linear, mul, and sigmoid operations, extending int16 support to tanh operations. Changes: - Add INT16 dtype validation support in op_tanh.py - Add test_tanh_tensor_16a8w_tosa_INT test function - Enable test_tanh.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. Differential Revision: D80510815
1 parent 2889a9a commit 4c0aefa

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

backends/arm/test/ops/test_tanh.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,21 @@
77
from typing import Tuple
88

99
import torch
10+
from executorch.backends.arm.quantizer.arm_quantizer import (
11+
get_symmetric_a16w8_quantization_config,
12+
TOSAQuantizer,
13+
)
1014

11-
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test import common, conftest
1216
from executorch.backends.arm.test.tester.test_pipeline import (
1317
EthosU55PipelineINT,
1418
EthosU85PipelineINT,
1519
TosaPipelineFP,
1620
TosaPipelineINT,
1721
VgfPipeline,
1822
)
23+
from executorch.backends.arm.tosa_specification import TosaSpecification
24+
from executorch.backends.xnnpack.test.tester import Quantize
1925

2026
aten_op = "torch.ops.aten.tanh.default"
2127
input_t1 = Tuple[torch.Tensor] # Input x
@@ -105,3 +111,48 @@ def test_tanh_vgf_INT(test_data: Tuple):
105111
tosa_version="TOSA-1.0+INT",
106112
)
107113
pipeline.run()
114+
115+
116+
def get_symmetric_a16w8_tanh_quantizer(
    u55_config=False, per_channel_quantization=False
):
    """Build a Quantize stage for 16-bit activations / 8-bit weights (16A8W).

    Args:
        u55_config: accepted for signature parity with sibling helpers;
            not consulted here (kept for backward compatibility).
        per_channel_quantization: forwarded to the a16w8 config as
            ``is_per_channel``.

    Returns:
        A ``Quantize`` stage wrapping a ``TOSAQuantizer`` whose global
        config is the symmetric a16w8 config.
    """
    tosa_version = conftest.get_option("tosa_version")
    # Only TOSA 1.0 with the int16 extension supports 16-bit activations.
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    # Build the a16w8 config once and reuse it for both the quantizer's
    # global setting and the Quantize stage (the original constructed an
    # identical config twice).
    config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(config)

    return Quantize(quantizer, config)
135+
136+
137+
@common.parametrize("test_data", test_data_suite)
def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
    """Run tanh through the TOSA INT pipeline with 16A8W quantization
    (16-bit activations, 8-bit weights), swapping in the a16w8 quantizer."""
    per_channel = False

    # Build the standard INT pipeline with the int16 TOSA extension enabled.
    pipeline = TosaPipelineINT[input_t1](
        Tanh(),
        (test_data(),),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    # Replace the default quantize stage with the 16A8W-specific one.
    a16w8_quantize = get_symmetric_a16w8_tanh_quantizer(
        per_channel_quantization=per_channel
    )
    pipeline.change_args("quantize", a16w8_quantize)
    pipeline.run()

0 commit comments

Comments
 (0)