Skip to content

Commit e725aa7

Browse files
Ninja91 authored and facebook-github-bot committed
Add 16A8W linear ops support and test (#13448)
Summary:
- Adds a linear ops test using the 16A8W config in the INT16 profile.
- Adds INT16 dtype support in view ops validation.
- Validated with the TOSA pipeline test.

Note: Not verified with a TOSA reference model run.

Differential Revision: D80308822
1 parent 1b3749d commit e725aa7

File tree

4 files changed

+39
-4
lines changed

4 files changed

+39
-4
lines changed

backends/arm/operators/op_view.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ def define_node(
4444
validate_valid_dtype(
4545
self.target,
4646
[inputs[0], output],
47-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
47+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
4848
output.tosa_spec,
4949
)
5050

backends/arm/test/ops/test_linear.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,9 @@
1111
import pytest
1212

1313
import torch
14+
from executorch.backends.arm.quantizer.arm_quantizer import (
15+
get_16a8w_quantization_config,
16+
)
1417
from executorch.backends.arm.test import common
1518

1619
from executorch.backends.arm.test.tester.test_pipeline import (
@@ -258,3 +261,31 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
258261
per_channel_quantization=per_channel_quantization,
259262
)
260263
pipeline.run()
264+
265+
266+
@common.parametrize("test_data", test_data_rank1_INT)
267+
def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
268+
"""Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
269+
test_data, out_features, has_bias, per_channel_quantization = test_data()
270+
in_features = test_data.shape[-1]
271+
272+
# Create pipeline with custom 16A8W quantization config
273+
pipeline = TosaPipelineINT[input_t1](
274+
Linear(
275+
in_features=in_features,
276+
out_features=out_features,
277+
bias=has_bias,
278+
),
279+
(test_data,),
280+
aten_op,
281+
exir_op=[],
282+
per_channel_quantization=per_channel_quantization,
283+
use_to_edge_transform_and_lower=True,
284+
quantization_config=get_16a8w_quantization_config(
285+
is_per_channel=per_channel_quantization
286+
),
287+
tosa_extensions=["int16"],
288+
)
289+
290+
# Run the pipeline
291+
pipeline.run()

backends/arm/test/pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,3 +3,4 @@ addopts = --strict-markers
33
markers =
44
slow: Tests that take long time
55
tosa_ref_model: Tests that use TOSA reference model # Temporary!
6+
flaky: Tests that are known to be flaky

backends/arm/test/tester/test_pipeline.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -341,6 +341,7 @@ def __init__(
341341
qtol: int = 1,
342342
dynamic_shapes: Optional[Tuple[Any]] = None,
343343
tosa_extensions: Optional[List[str]] = None,
344+
quantization_config: Optional[Any] = None,
344345
):
345346
if tosa_extensions is None:
346347
tosa_extensions = []
@@ -356,9 +357,11 @@ def __init__(
356357
)
357358

358359
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
359-
quantization_config = get_symmetric_quantization_config(
360-
is_per_channel=per_channel_quantization
361-
)
360+
# Use custom quantization config if provided, otherwise use default
361+
if quantization_config is None:
362+
quantization_config = get_symmetric_quantization_config(
363+
is_per_channel=per_channel_quantization
364+
)
362365
if symmetric_io_quantization:
363366
quantizer.set_io(quantization_config)
364367
quant_stage = Quantize(quantizer, quantization_config)

0 commit comments

Comments (0)