
Commit 822a711

Update addmm int16 for Ethos-U85
Differential Revision: D83627934
Pull Request resolved: #14714
1 parent 54bfd72 commit 822a711

File tree

2 files changed: +23 -6 lines changed


backends/arm/operators/op_bmm.py

Lines changed: 23 additions & 0 deletions
@@ -79,6 +79,12 @@ def define_node(
             input1_zp = input_qparams[1].get_zp_per_tensor()
             bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
             bmm_output_name = bmm_result.name
+        elif inputs[0].dtype == ts.DType.INT16:
+            input_qparams = get_input_qparams(node)
+            input0_zp = input_qparams[0].get_zp_per_tensor()
+            input1_zp = input_qparams[1].get_zp_per_tensor()
+            bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT48)
+            bmm_output_name = bmm_result.name
         else:
             bmm_output_name = output.name
             input0_zp, input1_zp = 0, 0
@@ -118,3 +124,20 @@ def define_node(
                 output_zp=[output_qparams.get_zp_per_tensor()],
                 rounding_mode=RoundingMode.SINGLE_ROUND,
             )
+        elif output.dtype == ts.DType.INT16:
+            output_qparams = get_output_qparams(node)[0]
+            final_output_scale = (
+                input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore[61]
+            ) / output_qparams.get_scale_per_tensor()
+
+            build_rescale(
+                tosa_fb=tosa_graph,
+                scale=[final_output_scale],
+                # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
+                input_node=bmm_result,  # type: ignore[possibly-undefined]
+                output_name=output.name,
+                output_type=ts.DType.INT16,
+                input_zp=[0],
+                output_zp=[output_qparams.get_zp_per_tensor()],
+                rounding_mode=RoundingMode.SINGLE_ROUND,
+            )
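
Why INT48 and a single rescale: a product of two int16 values can reach about 2^30, so accumulating along the reduction dimension can overflow a 32-bit accumulator, which is why the INT32 intermediate used on the INT8 path is widened to INT48 here. The sketch below is not part of the commit: simulate_int16_bmm is a hypothetical name, int64 stands in for the TOSA INT48 accumulator, round-half-even stands in for RoundingMode.SINGLE_ROUND, and symmetric (zero-point 0) int16 inputs are assumed, as implied by input_zp=[0] in the rescale call.

    import numpy as np

    def simulate_int16_bmm(a_q, b_q, s_in0, s_in1, s_out, out_zp):
        # Accumulate the integer matmul at wide precision (stand-in for INT48).
        acc = a_q.astype(np.int64) @ b_q.astype(np.int64)
        # Fold both input scales and the output scale into one multiplier,
        # matching final_output_scale in the diff above.
        final_output_scale = (s_in0 * s_in1) / s_out
        out = np.rint(acc * final_output_scale) + out_zp
        return np.clip(out, -32768, 32767).astype(np.int16)

    # Worst-case check: a length-4096 dot product of int16 extremes
    # overflows int32, so a 32-bit accumulator is not enough for this path.
    K = 4096
    a = np.full((1, K), 32767, dtype=np.int16)
    b = np.full((K, 1), -32768, dtype=np.int16)
    acc = a.astype(np.int64) @ b.astype(np.int64)
    print(acc.item() < np.iinfo(np.int32).min)  # True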

backends/arm/test/ops/test_addmm.py

Lines changed: 0 additions & 6 deletions
@@ -213,9 +213,6 @@ def get_symmetric_a16w8_addmm_quantizer(per_channel_quantization=False):
 
 
 @common.parametrize("test_data", test_data_suite)
-@pytest.mark.xfail(
-    reason="missing int16 addmm ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13979"
-)
 def test_addmm_16a8w_tosa_INT(test_data: input_t1):
     """Test addmm (FC layer) operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
@@ -268,9 +265,6 @@ def test_addmm_16a8w_u55_INT16(test_data: input_t1):
 
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations"
-)
 def test_addmm_16a8w_u85_INT16(test_data: input_t1):
     """Test addmm (FC layer) operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
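
With both xfail markers removed, the 16A8W addmm tests run as regular tests; the U85 test remains gated by @common.XfailIfNoCorstone320. A typical invocation, assuming a configured ExecuTorch Arm test environment with pytest installed:

    python -m pytest backends/arm/test/ops/test_addmm.py -k 16a8w -v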
