
Commit b759ae8

pytorchbot and Ninja91 authored
Add 16A8W linear ops support and test (pytorch#13922)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: pytorch#13899 by @Ninja91
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/18/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/18/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/18/orig
@diff-train-skip-merge

Co-authored-by: Nitin Jain <[email protected]>
1 parent 624463e · commit b759ae8
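For context, "16A8W" means 16-bit activations with 8-bit weights, as the new test's docstring spells out. A minimal sketch of what symmetric quantization at those two bit widths computes, in plain PyTorch (the helper below is illustrative only, not the ExecuTorch quantizer API):

import torch

def symmetric_scale(x: torch.Tensor, n_bits: int) -> float:
    # Symmetric scheme: zero point is 0; the scale maps max |x| onto the integer max.
    qmax = 2 ** (n_bits - 1) - 1  # 127 for int8, 32767 for int16
    return x.abs().max().item() / qmax

activations = torch.randn(4, 8)  # quantized to int16 (the "16A")
weights = torch.randn(16, 8)     # quantized to int8 (the "8W")

q_act = torch.round(activations / symmetric_scale(activations, 16)).clamp(-32768, 32767)
q_w = torch.round(weights / symmetric_scale(weights, 8)).clamp(-128, 127)

The wider activation range is what makes int16 support in ops like linear and view necessary downstream.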

2 files changed: 75 additions & 4 deletions

backends/arm/operators/op_view.py

Lines changed: 7 additions & 1 deletion
@@ -44,7 +44,13 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
+            [
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+                ts.DType.FP32,
+                ts.DType.BOOL,
+            ],
             output.tosa_spec,
         )
 
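In a 16A8W-quantized graph the activations that flow through shape ops are INT16 tensors, so the view op's dtype allow-list must accept INT16 or lowering fails validation. A standalone sketch of this kind of allow-list check (hypothetical helper; ExecuTorch's actual validate_valid_dtype takes the target, tensors, allowed dtypes, and TOSA spec as shown above):

# Hypothetical standalone dtype allow-list check for a view-style op.
ALLOWED_VIEW_DTYPES = {"INT8", "INT16", "INT32", "FP32", "BOOL"}

def check_view_dtypes(op_name: str, dtypes: list) -> None:
    for dtype in dtypes:
        if dtype not in ALLOWED_VIEW_DTYPES:
            raise ValueError(f"{op_name}: unsupported dtype {dtype}")

check_view_dtypes("view", ["INT16", "INT16"])  # passes after this change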

backends/arm/test/ops/test_linear.py

Lines changed: 68 additions & 3 deletions
@@ -9,9 +9,12 @@
 from typing import Tuple
 
 import pytest
-
 import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest
 
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -20,6 +23,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.linear.default"
 
@@ -143,7 +148,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
     pipeline.run()
 
 
-@pytest.mark.flaky(reruns=5)  # TODO: Investigate flakyness.
 @common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
 def test_linear_tosa_INT(test_data: torch.Tensor):
     test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -243,3 +247,64 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
         per_channel_quantization=per_channel_quantization,
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_linear_quantizer(
+    u55_config=False, per_channel_quantization=False
+):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+    quantizer.set_module_type(
+        torch.nn.Linear,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+@pytest.mark.xfail(
+    reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph"
+)
+def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    test_data, out_features, has_bias, per_channel_quantization = test_data()
+    in_features = test_data.shape[-1]
+
+    # Create pipeline with custom 16A8W quantization config
+    pipeline = TosaPipelineINT[input_t1](
+        Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=has_bias,
+        ),
+        (test_data,),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_linear_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    # Run the pipeline
+    pipeline.run()
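Assuming an executorch checkout with the Arm test dependencies installed (the test reads options such as tosa_version from the repo's Arm conftest fixtures), the new test could be selected with pytest's -k filter:

python -m pytest backends/arm/test/ops/test_linear.py -k test_linear_16a8w_tosa_INT

Note the test is marked xfail until int16 linear support lands, per the stated reason that the TOSA reference model run currently fails with "Invalid TOSA graph".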
