Skip to content

Commit aeed916

Browse files
authored
Rescale sub int16 correctly
Differential Revision: D83437623. Pull Request resolved: #14650.
1 parent e852066 commit aeed916

File tree

3 files changed

+118
-3
lines changed

(Summary repeated above: 3 files changed, +118 lines added, -3 lines removed.)

backends/arm/operators/op_sub.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def define_node(
5050
validate_valid_dtype(
5151
self.target,
5252
[*inputs, output],
53-
[ts.DType.INT8, ts.DType.INT32],
53+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
5454
output.tosa_spec,
5555
)
5656

@@ -59,12 +59,18 @@ def define_node(
5959
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
6060
tosa_graph, inputs, node, self.tosa_spec
6161
)
62+
elif inputs[0].dtype == ts.DType.INT16:
63+
rescaled_inputs, scale_back = (
64+
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
65+
tosa_graph, inputs, node, self.tosa_spec
66+
)
67+
)
6268
else:
6369
# input[0].dtype == ts.DType.INT32
6470
# Non quantized input, natively support by TOSA.SUB
6571
rescaled_inputs = inputs
6672

67-
if output.dtype == ts.DType.INT8:
73+
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
6874
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
6975
sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
7076
else:
@@ -95,6 +101,15 @@ def define_node(
95101
compute_rescale=False,
96102
tosa_spec=self.tosa_spec,
97103
) # type: ignore[possibly-undefined]
104+
elif output.dtype == ts.DType.INT16:
105+
tqutils.insert_rescale_op_to_int16(
106+
tosa_graph,
107+
sub_output,
108+
scale_back,
109+
node,
110+
compute_rescale=False,
111+
tosa_spec=self.tosa_spec,
112+
) # type: ignore[possibly-undefined]
98113

99114

100115
@register_node_visitor

backends/arm/test/ops/test_sub.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,21 @@
1010
from typing import Tuple
1111

1212
import torch
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
1317

14-
from executorch.backends.arm.test import common
18+
from executorch.backends.arm.test import common, conftest
1519
from executorch.backends.arm.test.tester.test_pipeline import (
1620
EthosU55PipelineINT,
1721
EthosU85PipelineINT,
1822
TosaPipelineFP,
1923
TosaPipelineINT,
2024
VgfPipeline,
2125
)
26+
from executorch.backends.arm.tosa import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2228

2329
aten_op = "torch.ops.aten.sub.Tensor"
2430
exir_op = "executorch_exir_dialects_edge__ops_aten_sub_Tensor"
@@ -242,3 +248,96 @@ def test_sub_tensor_vgf_INT_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
242248
tosa_version="TOSA-1.0+INT",
243249
)
244250
pipeline.run()
251+
252+
253+
def get_symmetric_a16w8_sub_quantizer(per_channel_quantization=False):
254+
tosa_version = conftest.get_option("tosa_version")
255+
tosa_profiles = {
256+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
257+
}
258+
259+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
260+
quantizer.set_global(
261+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
262+
)
263+
264+
return Quantize(
265+
quantizer,
266+
get_symmetric_a16w8_quantization_config(
267+
is_per_channel=per_channel_quantization
268+
),
269+
)
270+
271+
272+
@common.parametrize("test_data", sub_test_data)
273+
def test_sub_tensor_16a8w_tosa_INT(test_data: input_t1):
274+
"""Test sub operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
275+
per_channel_quantization = False
276+
277+
pipeline = TosaPipelineINT[input_t1](
278+
Sub(),
279+
test_data(),
280+
aten_op,
281+
exir_op=[],
282+
per_channel_quantization=per_channel_quantization,
283+
use_to_edge_transform_and_lower=True,
284+
tosa_extensions=["int16"],
285+
)
286+
287+
pipeline.change_args(
288+
"quantize",
289+
get_symmetric_a16w8_sub_quantizer(
290+
per_channel_quantization=per_channel_quantization
291+
),
292+
)
293+
pipeline.run()
294+
295+
296+
@common.parametrize("test_data", sub_test_data)
297+
@common.XfailIfNoCorstone300
298+
def test_sub_tensor_16a8w_u55_INT16(test_data: input_t1):
299+
"""Test sub operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
300+
per_channel_quantization = False
301+
302+
pipeline = EthosU55PipelineINT[input_t1](
303+
Sub(),
304+
test_data(),
305+
aten_op,
306+
exir_op,
307+
per_channel_quantization=per_channel_quantization,
308+
use_to_edge_transform_and_lower=True,
309+
run_on_fvp=True,
310+
)
311+
312+
pipeline.change_args(
313+
"quantize",
314+
get_symmetric_a16w8_sub_quantizer(
315+
per_channel_quantization=per_channel_quantization
316+
),
317+
)
318+
pipeline.run()
319+
320+
321+
@common.parametrize("test_data", sub_test_data)
322+
@common.XfailIfNoCorstone320
323+
def test_sub_tensor_16a8w_u85_INT16(test_data: input_t1):
324+
"""Test sub operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
325+
per_channel_quantization = False
326+
327+
pipeline = EthosU85PipelineINT[input_t1](
328+
Sub(),
329+
test_data(),
330+
aten_op,
331+
exir_op,
332+
per_channel_quantization=per_channel_quantization,
333+
use_to_edge_transform_and_lower=True,
334+
run_on_fvp=True,
335+
)
336+
337+
pipeline.change_args(
338+
"quantize",
339+
get_symmetric_a16w8_sub_quantizer(
340+
per_channel_quantization=per_channel_quantization
341+
),
342+
)
343+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def define_arm_tests():
2222
"ops/test_mul.py",
2323
"ops/test_slice.py",
2424
"ops/test_sigmoid.py",
25+
"ops/test_sub.py",
2526
"ops/test_tanh.py",
2627
"ops/test_view.py",
2728
"ops/test_cos.py",

0 commit comments

Comments (0)