Skip to content

Commit 6307924

Browse files
authored
Enable int16 for op permute
Differential Revision: D84948536 Pull Request resolved: pytorch#15256
1 parent bbb8a13 commit 6307924

File tree

3 files changed

+106
-6
lines changed

3 files changed

+106
-6
lines changed

backends/arm/operators/op_permute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def define_node(
117117
validate_valid_dtype(
118118
self.target,
119119
[inputs[0], output],
120-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
120+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
121121
output.tosa_spec,
122122
)
123123

backends/arm/test/ops/test_permute.py

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
from typing import Tuple
1010

1111
import torch
12-
13-
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
16+
from executorch.backends.arm.test import common, conftest
1417

1518
from executorch.backends.arm.test.tester.test_pipeline import (
1619
EthosU55PipelineINT,
@@ -19,7 +22,8 @@
1922
TosaPipelineINT,
2023
VgfPipeline,
2124
)
22-
from torchvision.ops import Permute
25+
from executorch.backends.arm.tosa import TosaSpecification
26+
from executorch.backends.xnnpack.test.tester import Quantize
2327

2428
input_t1 = Tuple[torch.Tensor] # Input x
2529

@@ -42,10 +46,10 @@ class SimplePermute(torch.nn.Module):
4246
def __init__(self, dims: list[int]):
4347
super().__init__()
4448

45-
self.permute = Permute(dims=dims)
49+
self.dims = dims
4650

4751
def forward(self, x):
48-
return self.permute(x)
52+
return torch.permute(x, self.dims)
4953

5054

5155
@common.parametrize("test_data", test_data_suite)
@@ -128,3 +132,98 @@ def test_permute_vgf_INT(test_data):
128132
tosa_version="TOSA-1.0+INT",
129133
)
130134
pipeline.run()
135+
136+
137+
def get_symmetric_a16w8_permute_quantizer(
138+
u55_config=False, per_channel_quantization=False
139+
):
140+
tosa_version = conftest.get_option("tosa_version")
141+
tosa_profiles = {
142+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
143+
}
144+
145+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
146+
quantizer.set_global(
147+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
148+
)
149+
150+
return Quantize(
151+
quantizer,
152+
get_symmetric_a16w8_quantization_config(
153+
is_per_channel=per_channel_quantization
154+
),
155+
)
156+
157+
158+
@common.parametrize("test_data", test_data_suite)
159+
def test_permute_int16_tosa_INT(test_data: torch.Tensor):
160+
"""Test permute operation with int16 quantization"""
161+
test_data, dims = test_data()
162+
pipeline = TosaPipelineINT[input_t1](
163+
SimplePermute(dims=dims),
164+
(test_data,),
165+
aten_op,
166+
exir_op=[],
167+
per_channel_quantization=False,
168+
use_to_edge_transform_and_lower=True,
169+
tosa_extensions=["int16"],
170+
)
171+
172+
pipeline.change_args(
173+
"quantize",
174+
get_symmetric_a16w8_permute_quantizer(per_channel_quantization=False),
175+
)
176+
# Run the pipeline
177+
pipeline.run()
178+
179+
180+
test_data_suite_exact = {
181+
x: test_data_suite[x] for x in test_data_suite if x != "rank_4_3"
182+
}
183+
184+
185+
@common.parametrize("test_data", test_data_suite_exact)
186+
@common.XfailIfNoCorstone300
187+
def test_permute_int16_u55_INT16(test_data: torch.Tensor):
188+
"""Test permute operation with int16 quantization on U55"""
189+
test_data, dims = test_data()
190+
pipeline = EthosU55PipelineINT[input_t1](
191+
SimplePermute(dims=dims),
192+
(test_data,),
193+
aten_op,
194+
exir_ops=[],
195+
per_channel_quantization=True,
196+
use_to_edge_transform_and_lower=True,
197+
atol=1e-02,
198+
rtol=1e-02,
199+
run_on_fvp=True,
200+
)
201+
202+
pipeline.change_args(
203+
"quantize",
204+
get_symmetric_a16w8_permute_quantizer(per_channel_quantization=False),
205+
)
206+
pipeline.run()
207+
208+
209+
@common.parametrize("test_data", test_data_suite)
210+
@common.XfailIfNoCorstone320
211+
def test_permute_int16_u85_INT16(test_data: torch.Tensor):
212+
"""Test permute operation with int16 quantization on U85"""
213+
test_data, dims = test_data()
214+
pipeline = EthosU85PipelineINT[input_t1](
215+
SimplePermute(dims=dims),
216+
(test_data,),
217+
aten_op,
218+
exir_ops=[],
219+
use_to_edge_transform_and_lower=True,
220+
atol=1e-03,
221+
rtol=1e-03,
222+
run_on_fvp=True,
223+
)
224+
225+
pipeline.change_args(
226+
"quantize",
227+
get_symmetric_a16w8_permute_quantizer(per_channel_quantization=False),
228+
)
229+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def define_arm_tests():
2020
"ops/test_cat.py",
2121
"ops/test_linear.py",
2222
"ops/test_mul.py",
23+
"ops/test_permute.py",
2324
"ops/test_slice.py",
2425
"ops/test_sigmoid.py",
2526
"ops/test_sub.py",

0 commit comments

Comments
 (0)