Commit 0b5a4ab

Update linear -> conv2d int16 for Ethos
Differential Revision: D83632029
Pull Request resolved: #14763
1 parent: 0ee1160

File tree

2 files changed: +5 lines, -15 lines


backends/arm/operators/op_conv2d.py

Lines changed: 3 additions & 3 deletions
@@ -182,11 +182,11 @@ def define_node(
             acc_type = ts.DType.FP32
 
         tosa_graph.addConst(
-            [1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
+            [1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
         )
         tosa_graph.addConst(
             [1],
-            output.dtype,
+            inputs[1].dtype,
             weight_zp,
             name=f"{conv2d_output_name}_weight_zp",
         )
@@ -269,7 +269,7 @@ def define_node(
 
         # For quantized convolution, rescale the output value back to the same
         # integer value domain of the next op. Otherwise return float32 output.
-        if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
+        if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
             # Get scale_factor from input, weight, and output.
             input_scale = input_qparams[0].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore [61]
             per_channel_quant = input_qparams[1].per_channel  # pyre-ignore [61]
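The first hunk ties each zero-point constant to the tensor it describes (inputs[0] for the activation, inputs[1] for the weight) rather than to the output, which matters for 16A8W convolutions where activations are INT16 but weights are INT8. The second hunk gates the final rescale on the output dtype. Below is a minimal, self-contained sketch of the requantization arithmetic behind that gate (pure Python, no TOSA dependency; the dtype strings and scale values are illustrative only, not taken from the commit):

# Sketch only: why a quantized conv output must be rescaled into the next
# op's integer domain, while a float output can be returned as-is.

def needs_output_rescale(output_dtype: str) -> bool:
    # Mirrors the updated condition: gate on the *output* dtype, since a
    # 16A8W conv has an INT16 activation but may still produce FP32 output.
    return output_dtype in ("INT8", "INT16")

def requantization_factor(input_scale: float, weight_scale: float,
                          output_scale: float) -> float:
    # Standard per-tensor requantization: the int32 accumulator carries an
    # effective scale of input_scale * weight_scale, which must be mapped
    # onto the output's scale before the next operator consumes it.
    return (input_scale * weight_scale) / output_scale

if __name__ == "__main__":
    assert needs_output_rescale("INT16") and not needs_output_rescale("FP32")
    print(requantization_factor(0.02, 0.005, 0.01))  # ~0.01

Gating on the output dtype rather than the input dtype means an int16-activation conv asked to produce FP32 skips the integer rescale, while any INT8/INT16 output is always brought back into the consumer's domain.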

backends/arm/test/ops/test_linear.py

Lines changed: 2 additions & 12 deletions
@@ -8,8 +8,6 @@
 
 from typing import Tuple
 
-import pytest
-
 import torch
 from executorch.backends.arm.quantizer.arm_quantizer import (
     get_symmetric_a16w8_quantization_config,
@@ -313,12 +311,8 @@ def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
     pipeline.run()
 
 
-@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+@common.parametrize("test_data", test_data_all_16a8w)
 @common.XfailIfNoCorstone300
-@pytest.mark.xfail(
-    reason="Ethos-U55 A16W8 linear: int16 matmul not yet supported; pending backend support or linear->conv1x1 lowering. See: https://github.com/pytorch/executorch/issues/13947",
-    strict=False,
-)
 def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
     """Test linear operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
     test_data, out_features, has_bias, per_channel_quantization = test_data()
@@ -347,12 +341,8 @@ def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
     pipeline.run()
 
 
-@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT)
+@common.parametrize("test_data", test_data_all_16a8w)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Ethos-U55 A16W8 linear: int16 matmul not yet supported; pending backend support or linear->conv1x1 lowering. See: https://github.com/pytorch/executorch/issues/13947",
-    strict=False,
-)
 def test_linear_16a8w_u85_INT16(test_data: torch.Tensor):
     """Test linear operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
     test_data, out_features, has_bias, per_channel_quantization = test_data()
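On the test side, the parametrization swap replaces an inline dict union of two rank-specific suites with one combined suite, and the pytest.mark.xfail markers are dropped, consistent with the commit title's linear -> conv2d int16 lowering making these cases pass. A small self-contained sketch of the dict-union parametrization pattern (case names and payloads are hypothetical stand-ins, not the suite's real data):

# Hypothetical stand-ins for the two per-rank case dictionaries.
test_data_rank1_INT = {"rank1": lambda: "rank-1 linear case"}
test_data_rank4_INT = {"rank4": lambda: "rank-4 linear case"}

# Old style: merge the suites at decoration time (PEP 584 dict union, 3.9+).
merged_cases = test_data_rank1_INT | test_data_rank4_INT

# New style: one curated dict, so the U55 and U85 tests share an identical
# case list and stay in sync when cases are added or removed.
test_data_all_16a8w = merged_cases  # stand-in for the dict in test_linear.py

for name, make_case in test_data_all_16a8w.items():
    print(name, "->", make_case())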
