Skip to content

Commit 57ffbf6

Browse files
3l1facebook-github-bot
authored andcommitted
Add tests for int16 rsqrt on Ethos-U55/U85 (#14770)
Summary: Fix Rsqrt op for int16 bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: Ninja91, digantdesai Differential Revision: D83802158
1 parent 7ce78c0 commit 57ffbf6

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

backends/arm/test/ops/test_rsqrt.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,23 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
17+
from executorch.backends.arm.test import common, conftest
1218

13-
from executorch.backends.arm.test import common
1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
1621
EthosU85PipelineINT,
1722
TosaPipelineFP,
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
21-
26+
from executorch.backends.arm.tosa import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2228

2329
aten_op = "torch.ops.aten.rsqrt.default"
2430
input_t1 = Tuple[torch.Tensor] # Input x
@@ -104,3 +110,96 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
104110
tosa_version="TOSA-1.0+INT",
105111
)
106112
pipeline.run()
113+
114+
115+
def get_symmetric_a16w8_rsqrt_quantizer(
116+
u55_config=False, per_channel_quantization=False
117+
):
118+
tosa_version = conftest.get_option("tosa_version")
119+
tosa_profiles = {
120+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
121+
}
122+
123+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
124+
quantizer.set_global(
125+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
126+
)
127+
128+
return Quantize(
129+
quantizer,
130+
get_symmetric_a16w8_quantization_config(
131+
is_per_channel=per_channel_quantization
132+
),
133+
)
134+
135+
136+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
137+
def test_rsqrt_int16_tosa_INT(test_tensor: torch.Tensor):
138+
"""Test rsqrt operation with int16 quantization"""
139+
pipeline = TosaPipelineINT[input_t1](
140+
Rsqrt(),
141+
test_tensor(),
142+
aten_op,
143+
exir_op=[],
144+
per_channel_quantization=False,
145+
use_to_edge_transform_and_lower=True,
146+
tosa_extensions=["int16"],
147+
)
148+
149+
pipeline.change_args(
150+
"quantize",
151+
get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False),
152+
)
153+
# Run the pipeline
154+
pipeline.run()
155+
156+
157+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
158+
@common.XfailIfNoCorstone300
159+
@pytest.mark.xfail(
160+
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
161+
)
162+
def test_rsqrt_int16_u55_INT16(test_tensor: torch.Tensor):
163+
"""Test rsqrt operation with int16 quantization on U55"""
164+
pipeline = EthosU55PipelineINT[input_t1](
165+
Rsqrt(),
166+
test_tensor(),
167+
aten_op,
168+
exir_ops=[],
169+
per_channel_quantization=True,
170+
use_to_edge_transform_and_lower=True,
171+
atol=1e-03,
172+
rtol=1e-03,
173+
run_on_fvp=True,
174+
)
175+
176+
pipeline.change_args(
177+
"quantize",
178+
get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=True),
179+
)
180+
pipeline.run()
181+
182+
183+
@common.parametrize("test_tensor", Rsqrt.test_parameters)
184+
@common.XfailIfNoCorstone320
185+
@pytest.mark.xfail(
186+
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
187+
)
188+
def test_rsqrt_int16_u85_INT16(test_tensor: torch.Tensor):
189+
"""Test rsqrt operation with int16 quantization on U85"""
190+
pipeline = EthosU85PipelineINT[input_t1](
191+
Rsqrt(),
192+
test_tensor(),
193+
aten_op,
194+
exir_ops=[],
195+
use_to_edge_transform_and_lower=True,
196+
atol=1e-03,
197+
rtol=1e-03,
198+
run_on_fvp=True,
199+
)
200+
201+
pipeline.change_args(
202+
"quantize",
203+
get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False),
204+
)
205+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def define_arm_tests():
2222
"ops/test_linear.py",
2323
"ops/test_mul.py",
2424
"ops/test_permute.py",
25+
"ops/test_rsqrt.py",
2526
"ops/test_slice.py",
2627
"ops/test_sigmoid.py",
2728
"ops/test_sub.py",

0 commit comments

Comments
 (0)