backends/arm/operators/op_transpose.py — 2 changes: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
             output.tosa_spec,
         )
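For context (not part of the diff itself): validate_valid_dtype appears to act as an allow-list gate, so adding ts.DType.INT16 is what lets int16 tensors through this operator. A minimal sketch of that contract, under the assumption that the helper raises when a tensor's dtype falls outside the list; the body and message here are illustrative, not the executorch source:

    # Illustrative sketch only: assumes an allow-list check that raises otherwise.
    def validate_valid_dtype(target, tensors, valid_dtypes, tosa_spec):
        for tensor in tensors:
            if tensor.dtype not in valid_dtypes:
                raise ValueError(
                    f"{target}: dtype {tensor.dtype} is not in the supported "
                    f"dtypes {valid_dtypes} for TOSA spec {tosa_spec}"
                )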
backends/arm/operators/op_view.py — 8 changes: 7 additions & 1 deletion

@@ -44,7 +44,13 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
+            [
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+                ts.DType.FP32,
+                ts.DType.BOOL,
+            ],
             output.tosa_spec,
         )

Collaborator review comment on this dtype list: "Same as for op_transpose.py"
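Aside, grounded only in names visible in this PR: the output.tosa_spec being validated against is a TosaSpecification, and the test file below constructs an int16-capable spec from a profile string. A hedged sketch of that construction:

    from executorch.backends.arm.tosa_specification import TosaSpecification

    # TOSA 1.0, INT profile, with the int16 extension enabled — the same
    # profile string test_linear.py uses for its 16A8W pipeline.
    spec = TosaSpecification.create_from_string("TOSA-1.0+INT+int16")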
backends/arm/test/ops/test_linear.py — 71 changes: 68 additions & 3 deletions

@@ -9,9 +9,12 @@
 from typing import Tuple

 import pytest
-
 import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest

 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -20,6 +23,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize

 aten_op = "torch.ops.aten.linear.default"

@@ -143,7 +148,6 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
     pipeline.run()


-@pytest.mark.flaky(reruns=5)  # TODO: Investigate flakyness.
 @common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
 def test_linear_tosa_INT(test_data: torch.Tensor):
     test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -258,3 +262,64 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
         per_channel_quantization=per_channel_quantization,
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_linear_quantizer(
+    u55_config=False, per_channel_quantization=False
+):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+    quantizer.set_module_type(
+        torch.nn.Linear,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+@pytest.mark.xfail(
+    reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph"
+)
+def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    test_data, out_features, has_bias, per_channel_quantization = test_data()
+    in_features = test_data.shape[-1]
+
+    # Create pipeline with custom 16A8W quantization config
+    pipeline = TosaPipelineINT[input_t1](
+        Linear(
+            in_features=in_features,
+            out_features=out_features,
+            bias=has_bias,
+        ),
+        (test_data,),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_linear_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    # Run the pipeline
+    pipeline.run()
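As a numerical aside, "16A8W" means activations quantize to int16 while weights stay int8. A self-contained sketch of symmetric per-tensor scaling at those two bit widths — the helper names here are illustrative, not the executorch quantizer API:

    import torch

    def symmetric_scale(t: torch.Tensor, qmax: int) -> float:
        # Symmetric scheme: the largest |value| maps to the integer extreme qmax.
        return max(t.abs().max().item(), 1e-8) / qmax

    x = torch.randn(4, 32)   # activations -> int16, qmax = 2**15 - 1
    w = torch.randn(16, 32)  # weights     -> int8,  qmax = 2**7 - 1
    x_scale = symmetric_scale(x, 2**15 - 1)
    w_scale = symmetric_scale(w, 2**7 - 1)
    x_q = torch.clamp((x / x_scale).round(), -(2**15), 2**15 - 1).to(torch.int16)
    w_q = torch.clamp((w / w_scale).round(), -(2**7), 2**7 - 1).to(torch.int8)
    # Dequantizing approximates the float linear:
    # (x_q.float() * x_scale) @ (w_q.float() * w_scale).T

The wider int16 activations presumably have to flow through the view/transpose ops inside the lowered linear graph, which would explain why this PR also extends the dtype allow-lists in op_transpose.py and op_view.py.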