Skip to content

Commit cd39120

Browse files
Ninja91 authored and facebook-github-bot committed
Add 16A8W linear ops support and test (#13448)
Summary:
- Adds a linear ops test using the 16A8W config in the INT16 profile.
- Adds support for the INT16 dtype in view ops validation.
- Validated with the TOSA pipeline test.

Note: Not verified with a TOSA reference model run.

Differential Revision: D80308822
1 parent de90da2 commit cd39120

File tree

4 files changed

+45
-10
lines changed

4 files changed

+45
-10
lines changed

backends/arm/operators/op_view.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def define_node(
4444
validate_valid_dtype(
4545
self.target,
4646
[inputs[0], output],
47-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
47+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
4848
output.tosa_spec,
4949
)
5050

backends/arm/test/ops/test_linear.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
import pytest
1212

1313
import torch
14+
from executorch.backends.arm.quantizer.arm_quantizer import (
15+
get_16a8w_quantization_config,
16+
)
1417
from executorch.backends.arm.test import common
1518

1619
from executorch.backends.arm.test.tester.test_pipeline import (
@@ -258,3 +261,31 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
258261
per_channel_quantization=per_channel_quantization,
259262
)
260263
pipeline.run()
264+
265+
266+
@common.parametrize("test_data", test_data_rank1_INT)
267+
def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
268+
"""Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
269+
test_data, out_features, has_bias, per_channel_quantization = test_data()
270+
in_features = test_data.shape[-1]
271+
272+
# Create pipeline with custom 16A8W quantization config
273+
pipeline = TosaPipelineINT[input_t1](
274+
Linear(
275+
in_features=in_features,
276+
out_features=out_features,
277+
bias=has_bias,
278+
),
279+
(test_data,),
280+
aten_op,
281+
exir_op=[],
282+
per_channel_quantization=per_channel_quantization,
283+
use_to_edge_transform_and_lower=True,
284+
quantization_config=get_16a8w_quantization_config(
285+
is_per_channel=per_channel_quantization
286+
),
287+
tosa_extensions=["int16"],
288+
)
289+
290+
# Run the pipeline
291+
pipeline.run()

backends/arm/test/pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ addopts = --strict-markers
33
markers =
44
slow: Tests that take long time
55
tosa_ref_model: Tests that use TOSA reference model # Temporary!
6+
flaky: Tests that are known to be flaky

backends/arm/test/tester/test_pipeline.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def __init__(
107107
transform_passes: Optional[
108108
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
109109
] = None,
110+
quantization_config: Optional[Any] = None,
110111
):
111112

112113
self.tester = ArmTester(
@@ -341,6 +342,7 @@ def __init__(
341342
qtol: int = 1,
342343
dynamic_shapes: Optional[Tuple[Any]] = None,
343344
tosa_extensions: Optional[List[str]] = None,
345+
quantization_config: Optional[Any] = None,
344346
):
345347
if tosa_extensions is None:
346348
tosa_extensions = []
@@ -354,15 +356,6 @@ def __init__(
354356
compile_spec = common.get_tosa_compile_spec(
355357
tosa_profiles[tosa_version], custom_path=custom_path
356358
)
357-
358-
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
359-
quantization_config = get_symmetric_quantization_config(
360-
is_per_channel=per_channel_quantization
361-
)
362-
if symmetric_io_quantization:
363-
quantizer.set_io(quantization_config)
364-
quant_stage = Quantize(quantizer, quantization_config)
365-
366359
super().__init__(
367360
module,
368361
test_data,
@@ -372,6 +365,16 @@ def __init__(
372365
use_to_edge_transform_and_lower,
373366
dynamic_shapes,
374367
)
368+
369+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
370+
# Use custom quantization config if provided, otherwise use default
371+
quantization_config = quantization_config or get_symmetric_quantization_config(
372+
is_per_channel=per_channel_quantization
373+
)
374+
if symmetric_io_quantization:
375+
quantizer.set_io(quantization_config)
376+
quant_stage = Quantize(quantizer, quantization_config)
377+
375378
self.add_stage(self.tester.quantize, quant_stage, pos=0)
376379

377380
self.add_stage_after(

0 commit comments

Comments
 (0)