Skip to content

Commit b564bd0

Browse files
3l1 authored and facebook-github-bot committed
Rescale add int16 correctly
Differential Revision: D82906134
1 parent 7f2a593 commit b564bd0

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

backends/arm/operators/op_add.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,16 @@ def define_node(
6464
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
6565
tosa_graph, inputs, node, self.tosa_spec
6666
)
67+
elif inputs[0].dtype == ts.DType.INT16:
68+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_int16_to_int32_maxscale(
69+
tosa_graph, inputs, node, self.tosa_spec
70+
)
6771
else:
6872
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
6973
# Non quantized input, natively support by TOSA.ADD
7074
rescaled_inputs = inputs
7175

72-
if output.dtype == ts.DType.INT8:
76+
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
7377
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
7478
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
7579
else:
@@ -99,6 +103,16 @@ def define_node(
99103
compute_rescale=False,
100104
tosa_spec=self.tosa_spec,
101105
) # type: ignore[possibly-undefined]
106+
elif output.dtype == ts.DType.INT16:
107+
tqutils.insert_rescale_op_to_int16(
108+
tosa_graph,
109+
add_output,
110+
scale_back,
111+
node,
112+
compute_rescale=False,
113+
tosa_spec=self.tosa_spec,
114+
) # type: ignore[possibly-undefined]
115+
102116

103117

104118
@register_node_visitor

backends/arm/test/ops/test_add.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,6 @@ def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
276276

277277
@common.parametrize("test_data", Add.test_data)
278278
@common.XfailIfNoCorstone300
279-
@pytest.mark.xfail(
280-
reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
281-
)
282279
def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
283280
"""Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
284281
per_channel_quantization = False

backends/arm/tosa/quant_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,57 @@ def insert_rescale_ops_to_int32_maxscale(
7676

7777
return [rescaled_lhs, rescaled_rhs], back_scale
7878

79+
def insert_rescale_ops_int16_to_int32_maxscale(
    tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
) -> tuple[list[Any], float]:
    """Rescale two int16 inputs of an ADD/SUB node up to int32.

    Unlike the generic int32 rescale path, this uses a common scale of
    2 * max(lhs_scale, rhs_scale) and pre-multiplies both per-operand scales
    by 1 << 12, gaining extra fractional precision while staying inside the
    32-bit accumulator (adding two int16 values with non-null zero points
    needs up to 18 bits, leaving 12 bits of headroom; dividing by the doubled
    max scale makes the effective left-shift 11).

    Returns:
        A pair of (rescaled input handles, back-scale factor). The back-scale
        is the factor later consumed by insert_rescale_op_to_int16.

    Raises:
        ValueError: if more than two inputs are supplied.
    """

    if len(inputs) > 2:
        raise ValueError("More than two inputs not supported")

    operands = inputs.copy()
    # Permute each operand's shape into TOSA dim order before rescaling.
    for operand in operands:
        operand.shape = [operand.shape[i] for i in operand.dim_order]

    lhs_qp, rhs_qp = get_input_qparams(node).values()
    lhs_scale = lhs_qp.get_scale_per_tensor()
    rhs_scale = rhs_qp.get_scale_per_tensor()

    # Shared scale for both operands; doubled so the sum cannot overflow it.
    common_scale = 2 * max(lhs_scale, rhs_scale)
    PRECISION_SHIFT = 12  # extra fractional bits carried through the int32 math

    rescaled = []
    for operand, qp, scale in (
        (operands[0], lhs_qp, lhs_scale),
        (operands[1], rhs_qp, rhs_scale),
    ):
        factor = (1 << PRECISION_SHIFT) * scale / common_scale
        rescaled.append(
            build_rescale_to_int32(
                tosa_graph,
                operand,
                qp.get_zp_per_tensor(),
                factor,
                tosa_spec=tosa_spec,
            )
        )

    # Back-scale maps the int32 result into the output's quantized domain,
    # undoing the PRECISION_SHIFT applied above.
    output_scale = get_output_qparams(node)[0].get_scale_per_tensor()
    back_scale = common_scale / (output_scale * (1 << PRECISION_SHIFT))

    return rescaled, back_scale
79130

80131
def insert_rescale_ops_to_int32(
81132
tosa_graph: Any,

0 commit comments

Comments
 (0)