Skip to content

Commit 95a055a

Browse files
Ninja91 authored and facebook-github-bot committed
Add 16A8W linear ops support and test (#13448)
Summary: Pull Request resolved: #13448 - Adds linear ops test using the 16A8W config in INT16 profile. - Adds support in view ops validation for INT16 Dtype. - Validated with TOSA pipeline test. - Checked earlier marked flaky tests no longer flaky and remove markers. Note: Not verified with tosa reference model run. Reviewed By: digantdesai Differential Revision: D80308822
1 parent 0199140 commit 95a055a

File tree

2 files changed

+72
-5
lines changed

2 files changed

+72
-5
lines changed

backends/arm/operators/op_view.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@ def define_node(
4444
validate_valid_dtype(
4545
self.target,
4646
[inputs[0], output],
47-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
47+
[
48+
ts.DType.INT8,
49+
ts.DType.INT16,
50+
ts.DType.INT32,
51+
ts.DType.FP32,
52+
ts.DType.BOOL,
53+
],
4854
output.tosa_spec,
4955
)
5056

backends/arm/test/ops/test_linear.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
from typing import Tuple
1010

11-
import pytest
12-
1311
import torch
14-
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
16+
from executorch.backends.arm.test import common, conftest
1517

1618
from executorch.backends.arm.test.tester.test_pipeline import (
1719
EthosU55PipelineINT,
@@ -20,6 +22,8 @@
2022
TosaPipelineINT,
2123
VgfPipeline,
2224
)
25+
from executorch.backends.arm.tosa_specification import TosaSpecification
26+
from executorch.backends.xnnpack.test.tester import Quantize
2327

2428
aten_op = "torch.ops.aten.linear.default"
2529

@@ -143,7 +147,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
143147
pipeline.run()
144148

145149

146-
@pytest.mark.flaky(reruns=5) # TODO: Investigate flakyness.
147150
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
148151
def test_linear_tosa_INT(test_data: torch.Tensor):
149152
test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -258,3 +261,61 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
258261
per_channel_quantization=per_channel_quantization,
259262
)
260263
pipeline.run()
264+
265+
266+
def get_symmetric_a16w8_linear_quantizer(
    u55_config=False, per_channel_quantization=False
):
    """Build a Quantize stage for 16-bit activations / 8-bit weights (16A8W).

    Args:
        u55_config: currently unused; kept for signature parity with the
            other per-target quantizer helpers in this file.
        per_channel_quantization: whether weight quantization is per channel.

    Returns:
        A ``Quantize`` pipeline stage wired with a ``TOSAQuantizer`` that
        applies the symmetric a16w8 config globally and to ``torch.nn.Linear``.

    Raises:
        KeyError: if the configured ``tosa_version`` has no int16-capable
            TOSA profile registered below.
    """
    tosa_version = conftest.get_option("tosa_version")
    # int16 activations require the TOSA int16 extension in the profile string.
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }
    if tosa_version not in tosa_profiles:
        # Explicit message instead of a bare KeyError from the dict lookup.
        raise KeyError(
            f"No int16 TOSA profile registered for tosa_version={tosa_version!r}"
        )

    # The same config is used globally, per module type, and by the Quantize
    # stage — build it once instead of three times.
    quantization_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(quantization_config)
    quantizer.set_module_type(torch.nn.Linear, quantization_config)

    return Quantize(quantizer, quantization_config)
291+
292+
293+
# BUG FIX: the original passed test_data_rank4_INT as a third positional
# argument to common.parametrize instead of merging it into the test-data
# dict; the sibling test_linear_tosa_INT uses `rank1 | rank4`. Merge them so
# the rank-4 cases are actually parametrized as test data.
@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)."""
    test_data, out_features, has_bias, per_channel_quantization = test_data()
    in_features = test_data.shape[-1]

    # Standard INT pipeline; the int16 TOSA extension is needed for
    # 16-bit activations.
    pipeline = TosaPipelineINT[input_t1](
        Linear(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
        ),
        (test_data,),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    # Swap the default quantize stage for the custom 16A8W quantizer.
    pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_linear_quantizer(
            per_channel_quantization=per_channel_quantization
        ),
    )
    # Run the pipeline
    pipeline.run()

0 commit comments

Comments
 (0)