Skip to content

Commit f7c009e

Browse files
authored
Rescale add int16 correctly (#14645)
Differential Revision: D82906134
1 parent db8d04f commit f7c009e

File tree

5 files changed

+73
-15
lines changed

5 files changed

+73
-15
lines changed

backends/arm/operators/op_add.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,18 @@ def define_node(
6464
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
6565
tosa_graph, inputs, node, self.tosa_spec
6666
)
67+
elif inputs[0].dtype == ts.DType.INT16:
68+
rescaled_inputs, scale_back = (
69+
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
70+
tosa_graph, inputs, node, self.tosa_spec
71+
)
72+
)
6773
else:
6874
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
6975
# Non quantized input, natively support by TOSA.ADD
7076
rescaled_inputs = inputs
7177

72-
if output.dtype == ts.DType.INT8:
78+
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
7379
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
7480
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
7581
else:
@@ -99,6 +105,15 @@ def define_node(
99105
compute_rescale=False,
100106
tosa_spec=self.tosa_spec,
101107
) # type: ignore[possibly-undefined]
108+
elif output.dtype == ts.DType.INT16:
109+
tqutils.insert_rescale_op_to_int16(
110+
tosa_graph,
111+
add_output,
112+
scale_back,
113+
node,
114+
compute_rescale=False,
115+
tosa_spec=self.tosa_spec,
116+
) # type: ignore[possibly-undefined]
102117

103118

104119
@register_node_visitor

backends/arm/test/ops/test_add.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,6 @@ def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
276276

277277
@common.parametrize("test_data", Add.test_data)
278278
@common.XfailIfNoCorstone300
279-
@pytest.mark.xfail(
280-
reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
281-
)
282279
def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
283280
"""Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
284281
per_channel_quantization = False
@@ -304,9 +301,6 @@ def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
304301

305302
@common.parametrize("test_data", Add.test_data)
306303
@common.XfailIfNoCorstone320
307-
@pytest.mark.xfail(
308-
reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
309-
)
310304
def test_add_tensor_16a8w_u85_INT16(test_data: input_t1):
311305
"""Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
312306
per_channel_quantization = False

backends/arm/test/ops/test_to_copy.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,20 +192,15 @@ def test_to_vgf_INT(test_data: Tuple):
192192
),
193193
}
194194

195-
redundant_xfails_FP = {
195+
redundant_xfails = {
196196
"rand_fp16_fp16": "FP16 is not supported",
197197
"rand_int8_int8": "Tracing graph with quantized input is not supported.",
198198
"rand_int16_int16": "Tracing graph with quantized input is not supported.",
199199
}
200200

201-
redundant_xfails_INT = {
202-
"rand_fp16_fp16": "FP16 is not supported",
203-
"rand_int8_int8": "Tracing graph with quantized input is not supported.",
204-
}
205-
206201

207202
@common.parametrize(
208-
"test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_FP
203+
"test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails
209204
)
210205
def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple):
211206
test_tensor, new_dtype = test_data()
@@ -220,7 +215,7 @@ def test_to_tosa_FP_REDUNDANT_CAST(test_data: Tuple):
220215

221216

222217
@common.parametrize(
223-
"test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails_INT
218+
"test_data", _TO_COPY_TEST_DATA_REDUNDANT_CAST, xfails=redundant_xfails
224219
)
225220
def test_to_tosa_INT_REDUNDANT_CAST(test_data: Tuple):
226221
test_tensor, new_dtype = test_data()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def define_arm_tests():
2525
"ops/test_tanh.py",
2626
"ops/test_view.py",
2727
"ops/test_cos.py",
28+
"ops/test_to_copy.py",
2829
]
2930

3031
# Quantization

backends/arm/tosa/quant_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,59 @@ def insert_rescale_ops_to_int32_maxscale(
7777
return [rescaled_lhs, rescaled_rhs], back_scale
7878

7979

80+
def insert_rescale_ops_int16_to_int32_maxscale(
81+
tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
82+
) -> tuple[list[Any], float]:
83+
"""For ADD and SUB with int16 inputs, we rescale to int32 using a different common scale(2*max(left scale,right scale))
84+
compared to all the other cases. We multiply the left and right scales by 1<<12 giving us extra precision
85+
for the computation without overflowing.
86+
87+
Returns a list of the rescaled nodes and the scale factor used,
88+
needed by insert_rescale_op_to_int16.
89+
"""
90+
91+
if len(inputs) > 2:
92+
raise ValueError("More than two inputs not supported")
93+
94+
tensors = inputs.copy()
95+
# Reshape tensor according to TOSA dim order
96+
for tensor in tensors:
97+
dim_order = tensor.dim_order
98+
tensor.shape = [tensor.shape[i] for i in dim_order]
99+
100+
input_qparams = get_input_qparams(node)
101+
lhs_qparams, rhs_qparams = input_qparams.values()
102+
lhs_scale = lhs_qparams.get_scale_per_tensor()
103+
rhs_scale = rhs_qparams.get_scale_per_tensor()
104+
# Common scale for the two numbers
105+
max_scale_2x = 2 * max(lhs_scale, rhs_scale)
106+
SHIFT_INT16 = 12
107+
# We are adding two int16 numbers. If the zero point is non-null, the result will be in the range [-131070;131070], therefore we need 18 bits for the result.
108+
# We have a 32-bit accumulator, so we can shift to the left by 12 bits and not overflow. In reality, because we divide by the 2*max(lhs_scale,rhs_scale)
109+
# we are shifting to the left by 11.
110+
lhs_factor = (1 << SHIFT_INT16) * lhs_scale / max_scale_2x
111+
rhs_factor = (1 << SHIFT_INT16) * rhs_scale / max_scale_2x
112+
rescaled_lhs = build_rescale_to_int32(
113+
tosa_graph,
114+
tensors[0],
115+
lhs_qparams.get_zp_per_tensor(),
116+
lhs_factor,
117+
tosa_spec=tosa_spec,
118+
)
119+
rescaled_rhs = build_rescale_to_int32(
120+
tosa_graph,
121+
tensors[1],
122+
rhs_qparams.get_zp_per_tensor(),
123+
rhs_factor,
124+
tosa_spec=tosa_spec,
125+
)
126+
out_qparam = get_output_qparams(node)[0]
127+
out_scale = out_qparam.get_scale_per_tensor()
128+
back_scale = max_scale_2x / (out_scale * (1 << SHIFT_INT16))
129+
130+
return [rescaled_lhs, rescaled_rhs], back_scale
131+
132+
80133
def insert_rescale_ops_to_int32(
81134
tosa_graph: Any,
82135
inputs: list[TosaArg],

0 commit comments

Comments
 (0)