Skip to content

Commit a237e06

Browse files
committed
Add 16A8W support and test for cat operation
Add 16A8W quantization support and test for the cat operation in ExecutorTorch ARM backend. This follows the pattern established for linear, mul, sigmoid, tanh, slice, and view/transpose operations, extending int16 support to cat operations. Changes: - Add test_cat_tensor_16a8w_tosa_INT test function - Enable test_cat.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. Differential Revision: [D80511455](https://our.internmc.facebook.com/intern/diff/D80511455/) [ghstack-poisoned]
1 parent ab4c473 commit a237e06

File tree

2 files changed

+108
-1
lines changed

2 files changed

+108
-1
lines changed

backends/arm/test/ops/test_cat.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
12-
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
17+
from executorch.backends.arm.test import common, conftest
1318

1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
@@ -18,6 +23,8 @@
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
26+
from executorch.backends.arm.tosa_specification import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2128

2229
input_t1 = Tuple[torch.Tensor] # Input x
2330

@@ -151,3 +158,102 @@ def test_cat_vgf_INT(test_data: Tuple):
151158
tosa_version="TOSA-1.0+INT",
152159
)
153160
pipeline.run()
161+
162+
163+
def get_symmetric_a16w8_cat_quantizer(per_channel_quantization=False):
164+
tosa_version = conftest.get_option("tosa_version")
165+
tosa_profiles = {
166+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
167+
}
168+
169+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
170+
quantizer.set_global(
171+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
172+
)
173+
174+
return Quantize(
175+
quantizer,
176+
get_symmetric_a16w8_quantization_config(
177+
is_per_channel=per_channel_quantization
178+
),
179+
)
180+
181+
182+
@common.parametrize("test_data", Cat.test_parameters)
183+
def test_cat_16a8w_tosa_INT(test_data: Tuple):
184+
"""Test cat operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
185+
per_channel_quantization = False
186+
187+
pipeline = TosaPipelineINT[input_t1](
188+
Cat(),
189+
test_data(),
190+
aten_op,
191+
exir_op=[],
192+
per_channel_quantization=per_channel_quantization,
193+
use_to_edge_transform_and_lower=True,
194+
tosa_extensions=["int16"],
195+
)
196+
197+
pipeline.change_args(
198+
"quantize",
199+
get_symmetric_a16w8_cat_quantizer(
200+
per_channel_quantization=per_channel_quantization
201+
),
202+
)
203+
pipeline.run()
204+
205+
206+
@common.parametrize("test_data", Cat.test_parameters)
207+
@common.XfailIfNoCorstone300
208+
@pytest.mark.xfail(
209+
reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations"
210+
)
211+
def test_cat_16a8w_u55_INT16(test_data: Tuple):
212+
"""Test cat operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
213+
per_channel_quantization = False
214+
215+
pipeline = EthosU55PipelineINT[input_t1](
216+
Cat(),
217+
test_data(),
218+
aten_op,
219+
exir_op,
220+
per_channel_quantization=per_channel_quantization,
221+
use_to_edge_transform_and_lower=True,
222+
run_on_fvp=True,
223+
)
224+
225+
pipeline.change_args(
226+
"quantize",
227+
get_symmetric_a16w8_cat_quantizer(
228+
per_channel_quantization=per_channel_quantization
229+
),
230+
)
231+
pipeline.run()
232+
233+
234+
@common.parametrize("test_data", Cat.test_parameters)
235+
@common.XfailIfNoCorstone320
236+
@pytest.mark.xfail(
237+
reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations"
238+
)
239+
def test_cat_16a8w_u85_INT16(test_data: Tuple):
240+
"""Test cat operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
241+
per_channel_quantization = False
242+
243+
pipeline = EthosU85PipelineINT[input_t1](
244+
Cat(),
245+
test_data(),
246+
aten_op,
247+
exir_op,
248+
per_channel_quantization=per_channel_quantization,
249+
use_to_edge_transform_and_lower=True,
250+
run_on_fvp=True,
251+
)
252+
253+
pipeline.change_args(
254+
"quantize",
255+
get_symmetric_a16w8_cat_quantizer(
256+
per_channel_quantization=per_channel_quantization
257+
),
258+
)
259+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def define_arm_tests():
1515
test_files += [
1616
"ops/test_add.py",
1717
"ops/test_avg_pool2d.py",
18+
"ops/test_cat.py",
1819
"ops/test_linear.py",
1920
"ops/test_mul.py",
2021
"ops/test_slice.py",

0 commit comments

Comments
 (0)