Commit 94284d7

Revert "Arm backend: Add 16A8W linear ops support and test (#13754)" (#13895)
This reverts commit f8156fb.
1 parent 19a54bc commit 94284d7

File tree

backends/arm/operators/op_transpose.py
backends/arm/operators/op_view.py
backends/arm/test/ops/test_linear.py

3 files changed, +12 -83 lines changed
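
The "16A8W" scheme named in the reverted change uses 16-bit activations and 8-bit weights, as spelled out in the docstring of the test removed below. The following standalone Python sketch only illustrates what symmetric 16A8W quantization means numerically; it is an assumption-level illustration and does not use ExecuTorch's actual quantizer APIs.

import torch

def symmetric_scale(x: torch.Tensor, bits: int) -> float:
    # Symmetric quantization: zero point is 0, the scale maps the largest
    # absolute value onto the top of the signed integer range.
    qmax = 2 ** (bits - 1) - 1
    return x.abs().max().item() / qmax

# Example tensors standing in for a linear layer's input and weight.
activations = torch.randn(8, 16)
weights = torch.randn(4, 16) * 0.1

a_scale = symmetric_scale(activations, bits=16)  # int16 range: [-32768, 32767]
w_scale = symmetric_scale(weights, bits=8)       # int8 range:  [-128, 127]

q_act = torch.clamp(torch.round(activations / a_scale), -32768, 32767).to(torch.int16)
q_w = torch.clamp(torch.round(weights / w_scale), -128, 127).to(torch.int8)

The wider activation range reduces quantization error for intermediate values, which is why the removed test requested the TOSA int16 extension (tosa_extensions=["int16"]).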

backends/arm/operators/op_transpose.py

Lines changed: 8 additions & 8 deletions
@@ -44,17 +44,17 @@ def define_node(
 
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [inputs[0], output], ts)
-
-        valid_dtypes = [ts.DType.BOOL]
-        if self.tosa_spec.support_integer():
-            valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16])
-        if self.tosa_spec.support_float():
-            valid_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
-
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            valid_dtypes,
+            [
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+                ts.DType.FP32,
+                ts.DType.BOOL,
+                ts.DType.FP16,
+            ],
             output.tosa_spec,
         )
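
The lines removed above (and the analogous ones in op_view.py below) derived the accepted dtype list from the capabilities of the active TOSA specification instead of hard-coding it. The standalone sketch below mirrors only that gating logic; DType and TosaSpec here are illustrative stand-ins for the serializer's ts.DType enum and the real tosa_spec object, not ExecuTorch APIs.

from dataclasses import dataclass
from enum import Enum, auto

class DType(Enum):
    # Stand-in for the ts.DType values referenced in the diff.
    BOOL = auto()
    INT8 = auto()
    INT16 = auto()
    FP16 = auto()
    FP32 = auto()

@dataclass
class TosaSpec:
    # Stand-in for the real tosa_spec capability queries.
    has_integer: bool
    has_float: bool

    def support_integer(self) -> bool:
        return self.has_integer

    def support_float(self) -> bool:
        return self.has_float

def gated_valid_dtypes(tosa_spec: TosaSpec) -> list:
    # Mirrors the removed logic: BOOL is always accepted; integer and float
    # dtypes are added only when the active profile supports them.
    valid_dtypes = [DType.BOOL]
    if tosa_spec.support_integer():
        valid_dtypes.extend([DType.INT8, DType.INT16])
    if tosa_spec.support_float():
        valid_dtypes.extend([DType.FP16, DType.FP32])
    return valid_dtypes

print([d.name for d in gated_valid_dtypes(TosaSpec(has_integer=True, has_float=False))])
# ['BOOL', 'INT8', 'INT16']

The revert returns to passing a fixed dtype list directly to validate_valid_dtype, as shown in the "+" lines of the diff above.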

backends/arm/operators/op_view.py

Lines changed: 1 addition & 7 deletions
@@ -41,16 +41,10 @@ def define_node(
 
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [inputs[0], output], ts)
-        valid_dtypes = [ts.DType.BOOL]
-        if self.tosa_spec.support_integer():
-            valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16])
-        if self.tosa_spec.support_float():
-            valid_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
-
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            valid_dtypes,
+            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL],
             output.tosa_spec,
         )

backends/arm/test/ops/test_linear.py

Lines changed: 3 additions & 68 deletions
@@ -9,12 +9,9 @@
 from typing import Tuple
 
 import pytest
+
 import torch
-from executorch.backends.arm.quantizer.arm_quantizer import (
-    get_symmetric_a16w8_quantization_config,
-    TOSAQuantizer,
-)
-from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test import common
 
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -23,8 +20,6 @@
     TosaPipelineINT,
     VgfPipeline,
 )
-from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.linear.default"
 
@@ -148,6 +143,7 @@ def test_linear_tosa_FP(test_data: torch.Tensor):
     pipeline.run()
 
 
+@pytest.mark.flaky(reruns=5)  # TODO: Investigate flakyness.
 @common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
 def test_linear_tosa_INT(test_data: torch.Tensor):
     test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -247,64 +243,3 @@ def test_linear_vgf_INT(test_data: torch.Tensor):
         per_channel_quantization=per_channel_quantization,
     )
     pipeline.run()
-
-
-def get_symmetric_a16w8_linear_quantizer(
-    u55_config=False, per_channel_quantization=False
-):
-    tosa_version = conftest.get_option("tosa_version")
-    tosa_profiles = {
-        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
-    }
-
-    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-    quantizer.set_global(
-        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
-    )
-    quantizer.set_module_type(
-        torch.nn.Linear,
-        get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization
-        ),
-    )
-
-    return Quantize(
-        quantizer,
-        get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization
-        ),
-    )
-
-
-@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
-@pytest.mark.xfail(
-    reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph"
-)
-def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
-    """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
-    test_data, out_features, has_bias, per_channel_quantization = test_data()
-    in_features = test_data.shape[-1]
-
-    # Create pipeline with custom 16A8W quantization config
-    pipeline = TosaPipelineINT[input_t1](
-        Linear(
-            in_features=in_features,
-            out_features=out_features,
-            bias=has_bias,
-        ),
-        (test_data,),
-        aten_op,
-        exir_op=[],
-        per_channel_quantization=per_channel_quantization,
-        use_to_edge_transform_and_lower=True,
-        tosa_extensions=["int16"],
-    )
-
-    pipeline.change_args(
-        "quantize",
-        get_symmetric_a16w8_linear_quantizer(
-            per_channel_quantization=per_channel_quantization
-        ),
-    )
-    # Run the pipeline
-    pipeline.run()
