Skip to content

Commit 8281869

Browse files
committed
Add 16A8W support and test for slice operation
Add 16A8W quantization support and test for the slice operation in ExecutorTorch ARM backend. This follows the pattern established for linear, mul, sigmoid, and tanh operations, extending int16 support to slice operations. Changes: - Add INT16 dtype validation support in op_slice.py - Add test_slice_tensor_16a8w_tosa_INT test function - Enable test_slice.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. Differential Revision: [D80511095](https://our.internmc.facebook.com/intern/diff/D80511095/) [ghstack-poisoned]
1 parent 1a048c5 commit 8281869

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

backends/arm/operators/op_slice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def define_node(
5757
validate_valid_dtype(
5858
self.target,
5959
[inputs[0], output],
60-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
60+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
6161
output.tosa_spec,
6262
)
6363

backends/arm/test/ops/test_slice.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@
77

88
from typing import Tuple
99

10+
import pytest
1011
import torch
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
1116

12-
from executorch.backends.arm.test import common
17+
from executorch.backends.arm.test import common, conftest
1318

1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
@@ -18,6 +23,8 @@
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
26+
from executorch.backends.arm.tosa_specification import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2128

2229
aten_op = "torch.ops.aten.slice.Tensor"
2330
exir_op = "executorch_exir_dialects_edge__ops_aten_slice_copy"
@@ -119,3 +126,102 @@ def test_slice_tensor_vgf_INT(test_data: torch.Tensor):
119126
tosa_version="TOSA-1.0+INT",
120127
)
121128
pipeline.run()
129+
130+
131+
def get_symmetric_a16w8_slice_quantizer(per_channel_quantization=False):
132+
tosa_version = conftest.get_option("tosa_version")
133+
tosa_profiles = {
134+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
135+
}
136+
137+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
138+
quantizer.set_global(
139+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
140+
)
141+
142+
return Quantize(
143+
quantizer,
144+
get_symmetric_a16w8_quantization_config(
145+
is_per_channel=per_channel_quantization
146+
),
147+
)
148+
149+
150+
@common.parametrize("test_data", test_data_suite)
151+
def test_slice_tensor_16a8w_tosa_INT(test_data: torch.Tensor):
152+
"""Test slice operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
153+
per_channel_quantization = False
154+
155+
pipeline = TosaPipelineINT[input_t1](
156+
Slice(),
157+
test_data(),
158+
aten_op,
159+
exir_op=[],
160+
per_channel_quantization=per_channel_quantization,
161+
use_to_edge_transform_and_lower=True,
162+
tosa_extensions=["int16"],
163+
)
164+
165+
pipeline.change_args(
166+
"quantize",
167+
get_symmetric_a16w8_slice_quantizer(
168+
per_channel_quantization=per_channel_quantization
169+
),
170+
)
171+
pipeline.run()
172+
173+
174+
@common.parametrize("test_data", test_data_suite)
175+
@common.XfailIfNoCorstone300
176+
@pytest.mark.xfail(
177+
reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations"
178+
)
179+
def test_slice_tensor_16a8w_u55_INT16(test_data: torch.Tensor):
180+
"""Test slice operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
181+
per_channel_quantization = False
182+
183+
pipeline = EthosU55PipelineINT[input_t1](
184+
Slice(),
185+
test_data(),
186+
aten_ops=[],
187+
exir_ops=[],
188+
per_channel_quantization=per_channel_quantization,
189+
use_to_edge_transform_and_lower=True,
190+
run_on_fvp=True,
191+
)
192+
193+
pipeline.change_args(
194+
"quantize",
195+
get_symmetric_a16w8_slice_quantizer(
196+
per_channel_quantization=per_channel_quantization
197+
),
198+
)
199+
pipeline.run()
200+
201+
202+
@common.parametrize("test_data", test_data_suite)
203+
@common.XfailIfNoCorstone320
204+
@pytest.mark.xfail(
205+
reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations"
206+
)
207+
def test_slice_tensor_16a8w_u85_INT16(test_data: torch.Tensor):
208+
"""Test slice operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
209+
per_channel_quantization = False
210+
211+
pipeline = EthosU85PipelineINT[input_t1](
212+
Slice(),
213+
test_data(),
214+
aten_ops=[],
215+
exir_ops=[],
216+
per_channel_quantization=per_channel_quantization,
217+
use_to_edge_transform_and_lower=True,
218+
run_on_fvp=True,
219+
)
220+
221+
pipeline.change_args(
222+
"quantize",
223+
get_symmetric_a16w8_slice_quantizer(
224+
per_channel_quantization=per_channel_quantization
225+
),
226+
)
227+
pipeline.run()

0 commit comments

Comments
 (0)