Commit 1656504

Arm backend: Addressing Linear operator issues in int16x8
Signed-off-by: Saoirse Stewart <[email protected]>
1 parent b24c39a commit 1656504

File tree

4 files changed: 44 additions, 54 deletions

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
backends/arm/_passes/rewrite_conv2d_pass.py
backends/arm/quantizer/quantization_config.py
backends/arm/test/ops/test_linear.py

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 12 additions & 32 deletions
@@ -50,12 +50,8 @@ def call_operator(self, op, args, kwargs, meta):
         )

         # convolution with bias and activation is int16
-        # The bias is assumed to be quantized with the same quantization parameters as
-        # as the output of the convolution
         bias = args[2]
-        assert (
-            meta.data["output_qparams"][0].dtype == bias.data.dtype
-        ), "Bias needs to have same type as quantized output type"
+
         no_bias_args = list(args)
         no_bias_args[2] = None
         # split up to convolution + bias
@@ -80,46 +76,30 @@ def call_operator(self, op, args, kwargs, meta):
         # The conv will get the output int48 scaled to int32 in serialization step.
         # To be able to add the bias we need to first scale (cast?) the output to int32.
         # The resulting i32 sum will then need to be scaled back to the output dtype.
-
-        # calculate common rescale factor from convolution output and bias quantization
         output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
         conv_output_scale = output_qparams.scale
-        bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
-        bias_scale = bias_qparams.scale

-        common_scale = max(bias_scale, conv_output_scale)
-
-        # calculate how we can rescale bias and conv to a common scale and maximize the output range
-        bias_rescale_factor = bias_scale / common_scale
-        conv_rescale_factor = conv_output_scale / common_scale
+        bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
+        per_channel_quant = bias_qparams.per_channel

-        # Either of conv output or bias now covers the full int16 range and the other one a smaller range.
-        # Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range.
-        # Worst case here is that both bias and conv output covers the full int16 range so we leave one bit
-        # and then one for the sign bit.
-        bits_left_to_shift = 14
+        if per_channel_quant:
+            bias_scale = bias_qparams.get_scale_per_channel()
+        else:
+            bias_scale = [bias_qparams.get_scale_per_tensor()]

-        # update rescale factors
-        bias_rescale_factor *= 1 << bits_left_to_shift
-        conv_rescale_factor *= 1 << bits_left_to_shift
+        conv_rescale_factors = [1.0] * len(bias_scale)
+        final_output_scale = [b / conv_output_scale for b in bias_scale]

         conv_output = super().call_operator(
             exir_ops.backend.tosa.RESCALE.default,
-            (convolution, torch.int32, [conv_rescale_factor], 0, 0),
-            {},
-            new_meta,
-        )
-
-        bias_rescaled = super().call_operator(
-            exir_ops.backend.tosa.RESCALE.default,
-            (channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
+            (convolution, torch.int32, conv_rescale_factors, 0, 0),
             {},
             new_meta,
         )

         add = super().call_operator(
             exir_ops.edge.aten.add.Tensor,
-            (conv_output, bias_rescaled),
+            (conv_output, channel_bias),
             {},
             new_meta,
         )
@@ -129,7 +109,7 @@ def call_operator(self, op, args, kwargs, meta):
             (
                 add,
                 output_dtype,
-                [(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
+                final_output_scale,
                 0,
                 0,
             ),
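
For context on the new rescale arithmetic: the int48 convolution accumulator and the int32 bias are both expressed in units of the bias scale, so the accumulator can be moved to int32 with a factor of 1.0, the bias added directly, and a single final RESCALE maps the sum onto the output scale. A minimal numeric sketch of that arithmetic (not part of the commit; the scales are invented, and the bias scale is assumed to be activation scale times per-channel weight scale, i.e. the derived scheme referenced in quantization_config.py):

input_scale = 0.02                       # example int16 activation scale
weight_scales = [0.001, 0.004]           # example per-channel int8 weight scales
conv_output_scale = 0.05                 # example int16 output scale

# Assumed derived bias quantization: one int32 scale per output channel.
bias_scale = [input_scale * w for w in weight_scales]

# The accumulator is already in bias_scale units, so it is moved to int32 unchanged.
conv_rescale_factors = [1.0] * len(bias_scale)

# After the int32 add, one RESCALE converts from bias_scale units to the output scale.
final_output_scale = [b / conv_output_scale for b in bias_scale]
print(final_output_scale)                # ≈ [0.0004, 0.0016]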

backends/arm/_passes/rewrite_conv2d_pass.py

Lines changed: 13 additions & 7 deletions
@@ -237,8 +237,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 pad[3],
                 dilation[1],
             )
-
-            if bias is None:
+            has_bias = bias is not None
+            if not has_bias:
                 bias = self._add_bias(graph_module, node, weight)

             if self._is_depthwise_conv2d(node):
@@ -278,14 +278,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             if (
                 tosa_node_fake_tensor.dtype == torch.int32
                 and input_fake_tensor.dtype == torch.int8
-            ) or (
-                tosa_node_fake_tensor.dtype == torch.int32
-                and input_fake_tensor.dtype == torch.int16
             ):
                 output_rescale = self.insert_output_rescale(graph_module, tosa_op)
                 node.replace_all_uses_with(output_rescale)
-                if input_fake_tensor.dtype == torch.int16:
-                    tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
+            elif (
+                tosa_node_fake_tensor.dtype == torch.int32
+                and input_fake_tensor.dtype == torch.int16
+            ):
+                has_bias = len(node.meta["input_qparams"]) > 2
+                if not has_bias:
+                    output_rescale = self.insert_output_rescale(graph_module, tosa_op)
+                    node.replace_all_uses_with(output_rescale)
+                else:
+                    node.replace_all_uses_with(tosa_op)
+                tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
             else:
                 node.replace_all_uses_with(tosa_op)
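
In short, the rewritten branch treats the two dtypes differently: int8 convolutions always get an output RESCALE inserted right after the TOSA op, while int16 convolutions keep their int48 accumulator (tagged with TosaSpecialDtype.INT48) and only get an immediate RESCALE when no bias follows, because the bias-add decomposition above emits the final rescale itself. A condensed paraphrase of that decision as a standalone helper (illustrative only, not the pass's API):

import torch

def needs_immediate_output_rescale(
    conv_dtype: torch.dtype, input_dtype: torch.dtype, has_bias: bool
) -> bool:
    # int8 inputs: the int32 accumulator is always rescaled straight after the conv.
    if conv_dtype == torch.int32 and input_dtype == torch.int8:
        return True
    # int16 inputs: keep the int48 accumulator and only rescale here when there is
    # no bias, since the bias-add path performs the final rescale to the output dtype.
    if conv_dtype == torch.int32 and input_dtype == torch.int16:
        return not has_bias
    return False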

backends/arm/quantizer/quantization_config.py

Lines changed: 6 additions & 10 deletions
@@ -183,9 +183,12 @@ def _derive_qparams_fn(
            raise ValueError(
                "Input activation and weight QuantizationConfig must be specified."
            )
-        if self.input_activation.dtype == self.weight.dtype == torch.int8:
-            # This is the default int8 quantization which uses the derived quantization
-            # calculated from the activation and weight scale
+
+        if (self.input_activation.dtype == self.weight.dtype == torch.int8) or (
+            self.input_activation.dtype == torch.int16
+            and self.weight.dtype == torch.int8
+        ):
+
            input_act = node.args[0]
            weight = node.args[1]

@@ -210,13 +213,6 @@ def _derive_qparams_fn(
                ch_axis=ch_axis,
            )
            return quantization_spec  # type: ignore[return-value]
-        elif (
-            self.input_activation.dtype == torch.int16
-            and self.weight.dtype == torch.int8
-        ):
-            # In case the activation is quantized to int16, the bias needs to be
-            # added after the convolution, so use the output quantization for this case.
-            return self.output_activation
        else:
            raise NotImplementedError(
                f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented"

backends/arm/test/ops/test_linear.py

Lines changed: 13 additions & 5 deletions
@@ -274,10 +274,6 @@ def get_symmetric_a16w8_linear_quantizer(


 test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT
-# TODO: Remove large rand test as they are flaky until sorted out why: MLETORCH-1377
-for k in list(test_data_all_16a8w.keys()):
-    if "large_rand" in k:
-        test_data_all_16a8w.pop(k)


 @common.parametrize("test_data", test_data_all_16a8w)
@@ -311,7 +307,19 @@ def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
     pipeline.run()


-@common.parametrize("test_data", test_data_all_16a8w)
+x_fails = {}
+for test_name in [
+    "model_linear_rank4_zeros",
+    "model_linear_rank4_negative_ones",
+    "model_linear_rank4_negative_large_rand",
+]:
+    for set_per_chan in ["True", "False"]:
+        x_fails[test_name + ",per_channel_quant={}".format(set_per_chan)] = (
+            "MLETORCH-1452: AssertionError: Output 0 does not match reference output."
+        )
+
+
+@common.parametrize("test_data", test_data_all_16a8w, x_fails)
 @common.XfailIfNoCorstone300
 def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
     """Test linear operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
