Skip to content

Commit 7f95941

Browse files
pytorchbot and Ninja91 authored
Arm backend: Add 16A8W support and test for slice operation (#14215)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13798 by @Ninja91 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/11/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/11/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/10/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/11/orig @diff-train-skip-merge --------- Co-authored-by: Nitin Jain <[email protected]>
1 parent add30f6 commit 7f95941

File tree

2 files changed

+111
-2
lines changed

2 files changed

+111
-2
lines changed

backends/arm/operators/op_slice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def define_node(
5757
validate_valid_dtype(
5858
self.target,
5959
[inputs[0], output],
60-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
60+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
6161
output.tosa_spec,
6262
)
6363

backends/arm/test/ops/test_slice.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@
77

88
from typing import Tuple
99

10+
import pytest
1011
import torch
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
1116

12-
from executorch.backends.arm.test import common
17+
from executorch.backends.arm.test import common, conftest
1318

1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
@@ -18,6 +23,8 @@
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
26+
from executorch.backends.arm.tosa.specification import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2128

2229
aten_op = "torch.ops.aten.slice.Tensor"
2330
exir_op = "executorch_exir_dialects_edge__ops_aten_slice_copy"
@@ -119,3 +126,105 @@ def test_slice_tensor_vgf_INT(test_data: torch.Tensor):
119126
tosa_version="TOSA-1.0+INT",
120127
)
121128
pipeline.run()
129+
130+
131+
def get_symmetric_a16w8_slice_quantizer(per_channel_quantization=False):
132+
tosa_version = conftest.get_option("tosa_version")
133+
tosa_profiles = {
134+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
135+
}
136+
137+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
138+
quantizer.set_global(
139+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
140+
)
141+
142+
return Quantize(
143+
quantizer,
144+
get_symmetric_a16w8_quantization_config(
145+
is_per_channel=per_channel_quantization
146+
),
147+
)
148+
149+
150+
@common.parametrize("test_data", test_data_suite)
151+
@pytest.mark.xfail(
152+
reason="missing int16 slice ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13976"
153+
)
154+
def test_slice_tensor_16a8w_tosa_INT(test_data: torch.Tensor):
155+
"""Test slice operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
156+
per_channel_quantization = False
157+
158+
pipeline = TosaPipelineINT[input_t1](
159+
Slice(),
160+
test_data(),
161+
aten_op,
162+
exir_op=[],
163+
per_channel_quantization=per_channel_quantization,
164+
use_to_edge_transform_and_lower=True,
165+
tosa_extensions=["int16"],
166+
)
167+
168+
pipeline.change_args(
169+
"quantize",
170+
get_symmetric_a16w8_slice_quantizer(
171+
per_channel_quantization=per_channel_quantization
172+
),
173+
)
174+
pipeline.run()
175+
176+
177+
@common.parametrize("test_data", test_data_suite)
178+
@common.XfailIfNoCorstone300
179+
@pytest.mark.xfail(
180+
reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations"
181+
)
182+
def test_slice_tensor_16a8w_u55_INT16(test_data: torch.Tensor):
183+
"""Test slice operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
184+
per_channel_quantization = False
185+
186+
pipeline = EthosU55PipelineINT[input_t1](
187+
Slice(),
188+
test_data(),
189+
aten_ops=[],
190+
exir_ops=[],
191+
per_channel_quantization=per_channel_quantization,
192+
use_to_edge_transform_and_lower=True,
193+
run_on_fvp=True,
194+
)
195+
196+
pipeline.change_args(
197+
"quantize",
198+
get_symmetric_a16w8_slice_quantizer(
199+
per_channel_quantization=per_channel_quantization
200+
),
201+
)
202+
pipeline.run()
203+
204+
205+
@common.parametrize("test_data", test_data_suite)
206+
@common.XfailIfNoCorstone320
207+
@pytest.mark.xfail(
208+
reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations"
209+
)
210+
def test_slice_tensor_16a8w_u85_INT16(test_data: torch.Tensor):
211+
"""Test slice operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
212+
per_channel_quantization = False
213+
214+
pipeline = EthosU85PipelineINT[input_t1](
215+
Slice(),
216+
test_data(),
217+
aten_ops=[],
218+
exir_ops=[],
219+
per_channel_quantization=per_channel_quantization,
220+
use_to_edge_transform_and_lower=True,
221+
run_on_fvp=True,
222+
)
223+
224+
pipeline.change_args(
225+
"quantize",
226+
get_symmetric_a16w8_slice_quantizer(
227+
per_channel_quantization=per_channel_quantization
228+
),
229+
)
230+
pipeline.run()

0 commit comments

Comments
 (0)