
Commit 5d268c7

SaoirseARM, oscarandersson8218, and zingo
authored
Arm backend: Addressing Linear operator issues in int16x8 (#15370)
### Summary
- Update the bias to use the derived quantization spec to int32.
- Fixes an issue in Conv2d where the bias should be added first and then rescaled back to int16.

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

Signed-off-by: Saoirse Stewart <[email protected]>
Co-authored-by: Oscar Andersson <[email protected]>
Co-authored-by: Zingo Andersen <[email protected]>
1 parent 5cf193a commit 5d268c7
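For context on the a16w8 (int16 activations, int8 weights) scheme this commit touches: with the change, the bias is quantized to int32 using the scale derived from the activation and weight scales, so it can be added to the widened convolution/linear accumulator before a single rescale to the int16 output. The NumPy sketch below only illustrates that arithmetic under those assumptions; the function and variable names are made up and are not the backend's API.

```python
import numpy as np

def simulate_a16w8_linear(x_q, w_q, bias_q, x_scale, w_scale, out_scale):
    """Illustrative a16w8 flow: int16 activations, int8 weights, int32 bias
    quantized with the derived scale x_scale * w_scale."""
    # The matmul accumulates in a wide integer type (int64 here stands in
    # for the backend's int48 accumulator).
    acc = x_q.astype(np.int64) @ w_q.astype(np.int64).T
    # The bias shares the accumulator scale, so it can be added directly.
    acc = acc + bias_q.astype(np.int64)
    # One final rescale from the accumulator scale to the int16 output scale.
    out = np.round(acc * (x_scale * w_scale / out_scale))
    return np.clip(out, -32768, 32767).astype(np.int16)

# Tiny end-to-end check against the float reference.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4)).astype(np.float32)
w = rng.standard_normal((3, 4)).astype(np.float32)
b = rng.standard_normal(3).astype(np.float32)

x_scale = float(np.abs(x).max()) / 32767
w_scale = float(np.abs(w).max()) / 127
out_scale = float(np.abs(x @ w.T + b).max()) / 32767

x_q = np.round(x / x_scale).astype(np.int16)
w_q = np.round(w / w_scale).astype(np.int8)
bias_q = np.round(b / (x_scale * w_scale)).astype(np.int32)

y_q = simulate_a16w8_linear(x_q, w_q, bias_q, x_scale, w_scale, out_scale)
print(np.abs(y_q * out_scale - (x @ w.T + b)).max())  # small quantization error
```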

File tree

4 files changed: +44 -54 lines


backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 12 additions & 32 deletions

@@ -49,12 +49,8 @@ def call_operator(self, op, args, kwargs, meta):
         )
 
         # convolution with bias and activation is int16
-        # The bias is assumed to be quantized with the same quantization parameters as
-        # as the output of the convolution
         bias = args[2]
-        assert (
-            meta.data["output_qparams"][0].dtype == bias.data.dtype
-        ), "Bias needs to have same type as quantized output type"
+
         no_bias_args = list(args)
         no_bias_args[2] = None
         # split up to convolution + bias
@@ -79,46 +75,30 @@ def call_operator(self, op, args, kwargs, meta):
         # The conv will get the output int48 scaled to int32 in serialization step.
         # To be able to add the bias we need to first scale (cast?) the output to int32.
         # The resulting i32 sum will then need to be scaled back to the output dtype.
-
-        # calculate common rescale factor from convolution output and bias quantization
         output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
         conv_output_scale = output_qparams.scale
-        bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
-        bias_scale = bias_qparams.scale
 
-        common_scale = max(bias_scale, conv_output_scale)
-
-        # calculate how we can rescale bias and conv to a common scale and maximize the output range
-        bias_rescale_factor = bias_scale / common_scale
-        conv_rescale_factor = conv_output_scale / common_scale
+        bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
+        per_channel_quant = bias_qparams.per_channel
 
-        # Either of conv output or bias now covers the full int16 range and the other one a smaller range.
-        # Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range.
-        # Worst case here is that both bias and conv output covers the full int16 range so we leave one bit
-        # and then one for the sign bit.
-        bits_left_to_shift = 14
+        if per_channel_quant:
+            bias_scale = bias_qparams.get_scale_per_channel()
+        else:
+            bias_scale = [bias_qparams.get_scale_per_tensor()]
 
-        # update rescale factors
-        bias_rescale_factor *= 1 << bits_left_to_shift
-        conv_rescale_factor *= 1 << bits_left_to_shift
+        conv_rescale_factors = [1.0] * len(bias_scale)
+        final_output_scale = [b / conv_output_scale for b in bias_scale]
 
         conv_output = super().call_operator(
             exir_ops.backend.tosa.RESCALE.default,
-            (convolution, torch.int32, [conv_rescale_factor], 0, 0),
-            {},
-            new_meta,
-        )
-
-        bias_rescaled = super().call_operator(
-            exir_ops.backend.tosa.RESCALE.default,
-            (channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
+            (convolution, torch.int32, conv_rescale_factors, 0, 0),
            {},
             new_meta,
         )
 
         add = super().call_operator(
             exir_ops.edge.aten.add.Tensor,
-            (conv_output, bias_rescaled),
+            (conv_output, channel_bias),
             {},
             new_meta,
         )
@@ -128,7 +108,7 @@ def call_operator(self, op, args, kwargs, meta):
             (
                 add,
                 output_dtype,
-                [(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
+                final_output_scale,
                 0,
                 0,
             ),
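A rough NumPy illustration of the new rescale in this pass, under the assumption that the int32 "conv + bias" sum carries the derived (per-channel) bias scale: the conv output is widened with a factor of 1.0, the int32 bias is added, and a single final rescale of bias_scale / conv_output_scale maps the sum onto the int16 output scale. The scales and accumulator values below are made up.

```python
import numpy as np

# Illustrative per-channel bias scales and an int16 output scale (made up).
bias_scale = np.array([0.00012, 0.00034, 0.00009])
conv_output_scale = 0.0005

# Int32 accumulator after "conv (widened by 1.0) + int32 bias", one row, 3 channels.
acc_i32 = np.array([[12345, -23456, 7890]], dtype=np.int64)
real_value = acc_i32 * bias_scale  # what the accumulator represents in float

# The pass folds everything into one RESCALE factor per channel.
final_output_scale = bias_scale / conv_output_scale
out_i16 = np.clip(np.round(acc_i32 * final_output_scale), -32768, 32767).astype(np.int16)

# Dequantizing the int16 result recovers the real-valued conv + bias (up to rounding).
print(np.abs(out_i16 * conv_output_scale - real_value).max() <= conv_output_scale)  # True
```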

backends/arm/_passes/rewrite_conv2d_pass.py

Lines changed: 13 additions & 7 deletions

@@ -237,8 +237,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 pad[3],
                 dilation[1],
             )
-
-            if bias is None:
+            has_bias = bias is not None
+            if not has_bias:
                 bias = self._add_bias(graph_module, node, weight)
 
             if self._is_depthwise_conv2d(node):
@@ -278,14 +278,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             if (
                 tosa_node_fake_tensor.dtype == torch.int32
                 and input_fake_tensor.dtype == torch.int8
-            ) or (
-                tosa_node_fake_tensor.dtype == torch.int32
-                and input_fake_tensor.dtype == torch.int16
             ):
                 output_rescale = self.insert_output_rescale(graph_module, tosa_op)
                 node.replace_all_uses_with(output_rescale)
-                if input_fake_tensor.dtype == torch.int16:
-                    tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
+            elif (
+                tosa_node_fake_tensor.dtype == torch.int32
+                and input_fake_tensor.dtype == torch.int16
+            ):
+                has_bias = len(node.meta["input_qparams"]) > 2
+                if not has_bias:
+                    output_rescale = self.insert_output_rescale(graph_module, tosa_op)
+                    node.replace_all_uses_with(output_rescale)
+                else:
+                    node.replace_all_uses_with(tosa_op)
+                tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
             else:
                 node.replace_all_uses_with(tosa_op)
 
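The int16 branch now makes its decision purely from the node metadata: input_qparams holds entries for (input, weight, bias), so more than two entries means a bias exists, the wide int32 result is left for the decompose pass above, and the op is tagged INT48 for serialization. A small stand-alone sketch of that check, where the plain dict is just a stand-in for the FX node metadata rather than the pass's real data structures:

```python
def should_rescale_int16_conv_here(input_qparams: dict) -> bool:
    """Mirror of the new check: rescale the output to int16 in this pass only
    when the conv has no bias; a biased conv keeps its wide int32 result so
    the decompose pass can add the bias and perform the final rescale."""
    has_bias = len(input_qparams) > 2  # qparams exist for (input, weight, bias)
    return not has_bias

# Conv with bias qparams at index 2: defer the rescale to the decompose pass.
print(should_rescale_int16_conv_here({0: "input", 1: "weight", 2: "bias"}))  # False
# Conv without bias: insert the output rescale right here.
print(should_rescale_int16_conv_here({0: "input", 1: "weight"}))  # True
```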

backends/arm/quantizer/quantization_config.py

Lines changed: 6 additions & 10 deletions

@@ -182,9 +182,12 @@ def _derive_qparams_fn(
            raise ValueError(
                "Input activation and weight QuantizationConfig must be specified."
            )
-        if self.input_activation.dtype == self.weight.dtype == torch.int8:
-            # This is the default int8 quantization which uses the derived quantization
-            # calculated from the activation and weight scale
+
+        if (self.input_activation.dtype == self.weight.dtype == torch.int8) or (
+            self.input_activation.dtype == torch.int16
+            and self.weight.dtype == torch.int8
+        ):
+
            input_act = node.args[0]
            weight = node.args[1]
 
@@ -209,13 +212,6 @@ def _derive_qparams_fn(
                ch_axis=ch_axis,
            )
            return quantization_spec  # type: ignore[return-value]
-        elif (
-            self.input_activation.dtype == torch.int16
-            and self.weight.dtype == torch.int8
-        ):
-            # In case the activation is quantized to int16, the bias needs to be
-            # added after the convolution, so use the output quantization for this case.
-            return self.output_activation
        else:
            raise NotImplementedError(
                f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented"

backends/arm/test/ops/test_linear.py

Lines changed: 13 additions & 5 deletions

@@ -274,10 +274,6 @@ def get_symmetric_a16w8_linear_quantizer(
 
 
 test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT
-# TODO: Remove large rand test as they are flaky until sorted out why: MLETORCH-1377
-for k in list(test_data_all_16a8w.keys()):
-    if "large_rand" in k:
-        test_data_all_16a8w.pop(k)
 
 
 @common.parametrize("test_data", test_data_all_16a8w)
@@ -311,7 +307,19 @@ def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
     pipeline.run()
 
 
-@common.parametrize("test_data", test_data_all_16a8w)
+x_fails = {}
+for test_name in [
+    "model_linear_rank4_zeros",
+    "model_linear_rank4_negative_ones",
+    "model_linear_rank4_negative_large_rand",
+]:
+    for set_per_chan in ["True", "False"]:
+        x_fails[test_name + ",per_channel_quant={}".format(set_per_chan)] = (
+            "MLETORCH-1452: AssertionError: Output 0 does not match reference output."
+        )
+
+
+@common.parametrize("test_data", test_data_all_16a8w, x_fails)
 @common.XfailIfNoCorstone300
 def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
     """Test linear operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
