Skip to content

Commit 535a094

Browse files
committed
Add 16A8W support and test for cat operation
Pull Request resolved: #13800 Add 16A8W quantization support and test for the cat operation in the ExecuTorch ARM backend. This follows the pattern established for linear, mul, sigmoid, tanh, slice, and view/transpose operations, extending int16 support to cat operations. Changes: - Add test_cat_16a8w_tosa_INT test function - Enable test_cat.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. ghstack-source-id: 304554484 @exported-using-ghexport Differential Revision: [D80511455](https://our.internmc.facebook.com/intern/diff/D80511455/)
1 parent f2f2de1 commit 535a094

File tree

2 files changed

+108
-1
lines changed

2 files changed

+108
-1
lines changed

backends/arm/test/ops/test_cat.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
12-
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
17+
from executorch.backends.arm.test import common, conftest
1318

1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
@@ -18,6 +23,8 @@
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
26+
from executorch.backends.arm.tosa_specification import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2128

2229
input_t1 = Tuple[torch.Tensor] # Input x
2330

@@ -151,3 +158,102 @@ def test_cat_vgf_INT(test_data: Tuple):
151158
tosa_version="TOSA-1.0+INT",
152159
)
153160
pipeline.run()
161+
162+
163+
def get_symmetric_a16w8_cat_quantizer(per_channel_quantization=False):
    """Build the ``Quantize`` stage used by the 16A8W cat tests.

    The TOSA specification is selected from the ``tosa_version`` test option.
    Only ``"1.0"`` has an entry in the lookup table, so any other value
    raises ``KeyError``.

    Args:
        per_channel_quantization: Forwarded as ``is_per_channel`` to
            ``get_symmetric_a16w8_quantization_config``.

    Returns:
        A ``Quantize`` stage wrapping a ``TOSAQuantizer`` whose global
        configuration is the symmetric a16w8 configuration.
    """
    requested_version = conftest.get_option("tosa_version")
    spec_by_version = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }
    a16w8_quantizer = TOSAQuantizer(spec_by_version[requested_version])

    # The configuration is created once for the quantizer's global setting ...
    global_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )
    a16w8_quantizer.set_global(global_config)

    # ... and a second, separately created instance is handed to the stage.
    stage_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )
    return Quantize(a16w8_quantizer, stage_config)
180+
181+
182+
@common.parametrize("test_data", Cat.test_parameters)
def test_cat_16a8w_tosa_INT(test_data: Tuple):
    """Run cat through the TOSA INT pipeline with 16A8W quantization.

    16A8W means 16-bit activations and 8-bit weights. The pipeline is created
    with the ``int16`` TOSA extension and its ``quantize`` stage arguments are
    replaced with the a16w8 cat quantizer before running.
    """
    per_channel = False

    # NOTE(review): exir_op is an empty list here rather than the module-level
    # `exir_op` name; confirm this is intentional.
    cat_pipeline = TosaPipelineINT[input_t1](
        Cat(),
        test_data(),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    a16w8_stage = get_symmetric_a16w8_cat_quantizer(
        per_channel_quantization=per_channel
    )
    cat_pipeline.change_args("quantize", a16w8_stage)
    cat_pipeline.run()
204+
205+
206+
@common.parametrize("test_data", Cat.test_parameters)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations"
)
def test_cat_16a8w_u55_INT16(test_data: Tuple):
    """Run cat through the Ethos-U55 INT pipeline with 16A8W quantization.

    16A8W means 16-bit activations and 8-bit weights. The test is marked as an
    expected failure; the stated reason is a Vela compilation error for int16
    cat operations.
    """
    per_channel = False

    u55_pipeline = EthosU55PipelineINT[input_t1](
        Cat(),
        test_data(),
        aten_op,
        exir_op,
        per_channel_quantization=per_channel,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Replace the quantize stage arguments with the a16w8 cat quantizer.
    a16w8_stage = get_symmetric_a16w8_cat_quantizer(
        per_channel_quantization=per_channel
    )
    u55_pipeline.change_args("quantize", a16w8_stage)
    u55_pipeline.run()
232+
233+
234+
@common.parametrize("test_data", Cat.test_parameters)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations"
)
def test_cat_16a8w_u85_INT16(test_data: Tuple):
    """Run cat through the Ethos-U85 INT pipeline with 16A8W quantization.

    16A8W means 16-bit activations and 8-bit weights. The test is marked as an
    expected failure; the stated reason is a Vela compilation error for int16
    cat operations.
    """
    per_channel = False

    u85_pipeline = EthosU85PipelineINT[input_t1](
        Cat(),
        test_data(),
        aten_op,
        exir_op,
        per_channel_quantization=per_channel,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Replace the quantize stage arguments with the a16w8 cat quantizer.
    a16w8_stage = get_symmetric_a16w8_cat_quantizer(
        per_channel_quantization=per_channel
    )
    u85_pipeline.change_args("quantize", a16w8_stage)
    u85_pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def define_arm_tests():
1515
test_files += [
1616
"ops/test_add.py",
1717
"ops/test_avg_pool2d.py",
18+
"ops/test_cat.py",
1819
"ops/test_linear.py",
1920
"ops/test_mul.py",
2021
"ops/test_slice.py",

0 commit comments

Comments
 (0)