Skip to content

Commit fe26bfd

Browse files
pytorchbotNinja91
andauthored
Arm backend: Add 16A8W support and test for cat operation (#14217)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13800 by @Ninja91 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/13/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/13/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/12/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/13/orig @diff-train-skip-merge --------- Co-authored-by: Nitin Jain <[email protected]>
1 parent c5d9c6f commit fe26bfd

File tree

2 files changed

+111
-1
lines changed

2 files changed

+111
-1
lines changed

backends/arm/test/ops/test_cat.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
12-
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
17+
from executorch.backends.arm.test import common, conftest
1318

1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
@@ -18,6 +23,8 @@
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
26+
from executorch.backends.arm.tosa.specification import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2128

2229
input_t1 = Tuple[torch.Tensor] # Input x
2330

@@ -151,3 +158,105 @@ def test_cat_vgf_INT(test_data: Tuple):
151158
tosa_version="TOSA-1.0+INT",
152159
)
153160
pipeline.run()
161+
162+
163+
def get_symmetric_a16w8_cat_quantizer(per_channel_quantization=False):
164+
tosa_version = conftest.get_option("tosa_version")
165+
tosa_profiles = {
166+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
167+
}
168+
169+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
170+
quantizer.set_global(
171+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
172+
)
173+
174+
return Quantize(
175+
quantizer,
176+
get_symmetric_a16w8_quantization_config(
177+
is_per_channel=per_channel_quantization
178+
),
179+
)
180+
181+
182+
@common.parametrize("test_data", Cat.test_parameters)
183+
@pytest.mark.xfail(
184+
reason="missing int16 cat ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13978"
185+
)
186+
def test_cat_16a8w_tosa_INT(test_data: Tuple):
187+
"""Test cat operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
188+
per_channel_quantization = False
189+
190+
pipeline = TosaPipelineINT[input_t1](
191+
Cat(),
192+
test_data(),
193+
aten_op,
194+
exir_op=[],
195+
per_channel_quantization=per_channel_quantization,
196+
use_to_edge_transform_and_lower=True,
197+
tosa_extensions=["int16"],
198+
)
199+
200+
pipeline.change_args(
201+
"quantize",
202+
get_symmetric_a16w8_cat_quantizer(
203+
per_channel_quantization=per_channel_quantization
204+
),
205+
)
206+
pipeline.run()
207+
208+
209+
@common.parametrize("test_data", Cat.test_parameters)
210+
@common.XfailIfNoCorstone300
211+
@pytest.mark.xfail(
212+
reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations"
213+
)
214+
def test_cat_16a8w_u55_INT16(test_data: Tuple):
215+
"""Test cat operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
216+
per_channel_quantization = False
217+
218+
pipeline = EthosU55PipelineINT[input_t1](
219+
Cat(),
220+
test_data(),
221+
aten_op,
222+
exir_op,
223+
per_channel_quantization=per_channel_quantization,
224+
use_to_edge_transform_and_lower=True,
225+
run_on_fvp=True,
226+
)
227+
228+
pipeline.change_args(
229+
"quantize",
230+
get_symmetric_a16w8_cat_quantizer(
231+
per_channel_quantization=per_channel_quantization
232+
),
233+
)
234+
pipeline.run()
235+
236+
237+
@common.parametrize("test_data", Cat.test_parameters)
238+
@common.XfailIfNoCorstone320
239+
@pytest.mark.xfail(
240+
reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations"
241+
)
242+
def test_cat_16a8w_u85_INT16(test_data: Tuple):
243+
"""Test cat operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
244+
per_channel_quantization = False
245+
246+
pipeline = EthosU85PipelineINT[input_t1](
247+
Cat(),
248+
test_data(),
249+
aten_op,
250+
exir_op,
251+
per_channel_quantization=per_channel_quantization,
252+
use_to_edge_transform_and_lower=True,
253+
run_on_fvp=True,
254+
)
255+
256+
pipeline.change_args(
257+
"quantize",
258+
get_symmetric_a16w8_cat_quantizer(
259+
per_channel_quantization=per_channel_quantization
260+
),
261+
)
262+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def define_arm_tests():
1515
test_files += [
1616
"ops/test_add.py",
1717
"ops/test_avg_pool2d.py",
18+
"ops/test_cat.py",
1819
"ops/test_linear.py",
1920
"ops/test_mul.py",
2021
"ops/test_slice.py",

0 commit comments

Comments
 (0)