Commit 6dc9cda
Arm backend: Decompose conv2d with 16-bit activation

Support quantization to 16a8w. Since the resulting TOSA operator needs its bias in int48, which is not available as a dtype in torch, the conv2d is decomposed into a conv + add, where the conv result is rescaled down to 32 bits before the bias is added.

Signed-off-by: Per Åstrand <[email protected]>
Change-Id: Ib8cae694035796374a55a9909e501596e983abf5
1 parent 6c22a86 commit 6dc9cda
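For illustration only, here is a minimal numeric sketch (with made-up example scales, not values from this patch) of the rescale-and-add arithmetic the decomposition performs: the conv result and the bias are rescaled to a common scale widened into int32, summed, and rescaled back to int16.

import torch

conv_output_scale = 0.02   # hypothetical output quantization scale
bias_scale = 0.005         # hypothetical bias quantization scale
common_scale = max(bias_scale, conv_output_scale)
shift = 14                 # headroom bits used when widening int16 -> int32

# rescale factors corresponding to the two tosa.RESCALE ops inserted by the pass
conv_factor = (conv_output_scale / common_scale) * (1 << shift)
bias_factor = (bias_scale / common_scale) * (1 << shift)

q_conv = torch.tensor([1234, -567], dtype=torch.int32)  # conv result at conv_output_scale
q_bias = torch.tensor([100, -25], dtype=torch.int32)    # bias quantized at bias_scale

# add at the common (shifted) scale in int32, then rescale back to the int16 output
acc = q_conv.double() * conv_factor + q_bias.double() * bias_factor
q_out = torch.round(acc * (common_scale / (conv_output_scale * (1 << shift)))).to(torch.int16)
print(q_out)  # tensor([1259, -573]) ≈ q_conv + q_bias * bias_scale / conv_output_scale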

4 files changed, +197 -23 lines changed

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -46,6 +46,9 @@
 from .decompose_glu_pass import DecomposeGluPass  # noqa
 from .decompose_grouped_conv import DecomposeGroupedConv  # noqa
 from .decompose_groupnorm_pass import DecomposeGroupNormPass  # noqa
+from .decompose_int16_activation_conv2d_pass import (  # noqa
+    DecomposeConv2dWithInt16ActivationPass,
+)
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass  # noqa
backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 1 deletion

@@ -39,6 +39,7 @@
     DecomposeAtanPass,
     DecomposeAvgPool2d,
     DecomposeBatchNormNoStatsPass,
+    DecomposeConv2dWithInt16ActivationPass,
     DecomposeCoshPass,
     DecomposeCosineSimilarityPass,
     DecomposeCumsumPass,
@@ -154,6 +155,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ComputeConstantOpsAOT(exported_program))

         self.add_pass(DecomposeGroupedConv())
+
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
@@ -167,9 +169,14 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

         self.add_pass(FuseViewCopyTransform())
         self.add_pass(FuseConstantArgsPass(exported_program))
+        self.add_pass(InsertTableOpsPass(exported_program))
+        # A conv2d with an int16 activation is split into a convolution and an
+        # addition to work around the lack of int48 support in torch. This has
+        # to happen before AddBiasPass, but after the table ops are inserted,
+        # so that the conv2d dtype arguments can be validated.
+        self.add_pass(DecomposeConv2dWithInt16ActivationPass())
         self.add_pass(AddBiasPass(exported_program))

-        self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(RemoveNoopPass())
backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import cast
+
+import torch
+from executorch.backends.arm._passes.quant_args import QuantArgs
+
+from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+class DecomposeConv2dWithInt16ActivationPass(ExportPass):
+    """
+    This pass decomposes a convolution with int16 input activation and a bias
+    into a convolution without bias followed by an addition of the bias,
+    since the TOSA op requires the bias to be int48, which is hard to represent
+    in torch. Instead, rescale the int48 output to int16 and add the bias in int16.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.aten.convolution.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        tosa_spec = get_context_spec()
+        if not tosa_spec.support_integer():
+            return super().call_operator(op, args, kwargs, meta)
+
+        # return if no bias
+        if args[2] is None:
+            return super().call_operator(op, args, kwargs, meta)
+
+        if args[0].data.dtype == torch.int8:
+            return super().call_operator(op, args, kwargs, meta)
+        elif args[0].data.dtype == torch.int16:
+            if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
+                "int16"
+            ):
+                raise ValueError(
+                    "int16 activation for convolution requires TOSA int16 extension"
+                )
+        else:
+            raise NotImplementedError(
+                "Decomposition to conv+add only implemented for activation of int16 type"
+            )
+
+        # convolution with bias and activation is int16
+        # The bias is assumed to be quantized with the same quantization parameters
+        # as the output of the convolution
+        bias = args[2]
+        assert (
+            meta.data["output_qparams"][0].dtype == bias.data.dtype
+        ), "Bias needs to have same type as quantized output type"
+        no_bias_args = list(args)
+        no_bias_args[2] = None
+        # split up to convolution + bias
+        convolution = super().call_operator(op, tuple(no_bias_args), kwargs, meta)
+
+        # create a copy of the meta without the qparams, to be used with the new nodes
+        new_meta = meta.copy()
+        new_meta.data.pop("output_qparams", None)
+        new_meta.data.pop("input_qparams", None)
+
+        # reshape the tensor to the same rank as the convolution output to add the bias to the channels
+        channel_bias = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (bias, [1, len(bias.data), 1, 1]),
+            {},
+            new_meta,
+        )
+
+        output_dtype = meta.data["output_qparams"][0].dtype
+
+        if output_dtype == torch.int16:
+            # The conv will get the output int48 scaled to int32 in the serialization step.
+            # To be able to add the bias we need to first scale (cast?) the output to int32.
+            # The resulting i32 sum will then need to be scaled back to the output dtype.
+
+            # calculate common rescale factor from convolution output and bias quantization
+            output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
+            conv_output_scale = output_qparams.scale
+            bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
+            bias_scale = bias_qparams.scale
+
+            common_scale = max(bias_scale, conv_output_scale)
+
+            # calculate how we can rescale bias and conv to a common scale and maximize the output range
+            bias_rescale_factor = bias_scale / common_scale
+            conv_rescale_factor = conv_output_scale / common_scale
+
+            # Either the conv output or the bias now covers the full int16 range and the other one a smaller range.
+            # Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range.
+            # Worst case here is that both bias and conv output cover the full int16 range, so we leave one bit
+            # of headroom and one for the sign bit.
+            bits_left_to_shift = 14

+            # update rescale factors
+            bias_rescale_factor *= 1 << bits_left_to_shift
+            conv_rescale_factor *= 1 << bits_left_to_shift
+
+            conv_output = super().call_operator(
+                exir_ops.backend.tosa.RESCALE.default,
+                (convolution, torch.int32, conv_rescale_factor, 0, 0),
+                {},
+                new_meta,
+            )
+
+            bias_rescaled = super().call_operator(
+                exir_ops.backend.tosa.RESCALE.default,
+                (channel_bias, torch.int32, bias_rescale_factor, 0, 0),
+                {},
+                new_meta,
+            )
+
+            add = super().call_operator(
+                exir_ops.edge.aten.add.Tensor,
+                (conv_output, bias_rescaled),
+                {},
+                new_meta,
+            )
+
+            res_rescale = super().call_operator(
+                exir_ops.backend.tosa.RESCALE.default,
+                (
+                    add,
+                    output_dtype,
+                    (common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
+                    0,
+                    0,
+                ),
+                {},
+                new_meta,
+            )
+
+        else:
+            raise NotImplementedError(
+                f"Decomposition to conv+add only implemented for activation of int16 type, not for {output_dtype}"
+            )
+
+        return res_rescale
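As a quick illustrative check (not part of the patch) of the headroom reasoning behind bits_left_to_shift = 14: leaving one spare bit plus the sign bit means that even when both the rescaled conv output and the rescaled bias sit at the int16 extreme, their int32 sum cannot overflow.

# Illustrative only: worst-case operands after the RESCALE to int32 with a shift of 14.
int16_max = 32767
shifted = int16_max << 14          # an operand covering the full int16 range, shifted left by 14
assert 2 * shifted <= 2**31 - 1    # the int32 addition in the decomposition stays in range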

backends/arm/quantizer/quantization_config.py

Lines changed: 41 additions & 22 deletions

@@ -89,29 +89,48 @@ def _derive_qparams_fn(
             torch.ops.aten.linear.default,
             torch.ops.aten.conv2d.padding,
         ]:
-            input_act = node.args[0]
-            weight = node.args[1]
-            # If the weights are quantized per_tensor, do the same with bias
-            qscheme = (
-                torch.per_tensor_symmetric
-                if self.weight is None
-                else self.weight.qscheme
-            )
-            ch_axis = None
-            if self.weight is not None:
-                if qscheme == torch.per_channel_symmetric:
-                    ch_axis = self.weight.ch_axis
+            if self.input_activation is None or self.weight is None:
+                raise ValueError(
+                    "Input activation and weight QuantizationConfig must be specified."
+                )
+            if self.input_activation.dtype == self.weight.dtype == torch.int8:
+                # This is the default int8 quantization which uses the derived quantization
+                # calculated from the activation and weight scale
+                input_act = node.args[0]
+                weight = node.args[1]

-            quantization_spec = DerivedQuantizationSpec(
-                derived_from=[(input_act, node), (weight, node)],  # type: ignore[list-item]
-                derive_qparams_fn=_derive_qparams_fn,
-                dtype=torch.int32,
-                quant_min=torch.iinfo(torch.int32).min,
-                quant_max=torch.iinfo(torch.int32).max - 1,
-                qscheme=qscheme,
-                ch_axis=ch_axis,
-            )
-            return quantization_spec  # type: ignore[return-value]
+                # If the weights are quantized per_tensor, do the same with bias
+                qscheme = (
+                    torch.per_tensor_symmetric
+                    if self.weight is None
+                    else self.weight.qscheme
+                )
+                ch_axis = None
+                if self.weight is not None:
+                    if qscheme == torch.per_channel_symmetric:
+                        ch_axis = self.weight.ch_axis
+
+                quantization_spec = DerivedQuantizationSpec(
+                    derived_from=[(input_act, node), (weight, node)],  # type: ignore[list-item]
+                    derive_qparams_fn=_derive_qparams_fn,
+                    dtype=torch.int32,
+                    quant_min=torch.iinfo(torch.int32).min,
+                    quant_max=torch.iinfo(torch.int32).max - 1,
+                    qscheme=qscheme,
+                    ch_axis=ch_axis,
+                )
+                return quantization_spec  # type: ignore[return-value]
+            elif (
+                self.input_activation.dtype == torch.int16
+                and self.weight.dtype == torch.int8
+            ):
+                # In case the activation is quantized to int16, the bias needs to be
+                # added after the convolution, so use the output quantization for this case.
+                return self.output_activation
+            else:
+                raise NotImplementedError(
+                    f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented"
+                )

         if self.bias is None:
             return None
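Illustrative arithmetic only (the scale values below are hypothetical, not from this patch): how the bias quantization differs between the two branches above. In the int8 branch the bias gets an int32 DerivedQuantizationSpec whose scale is typically derived from the activation and weight scales by _derive_qparams_fn, while in the 16a8w branch the bias simply reuses the output activation quantization so DecomposeConv2dWithInt16ActivationPass can add it after the convolution.

input_scale, weight_scale, output_scale = 0.03, 0.002, 0.05  # hypothetical scales

# int8 activations: int32 bias, scale derived from the activation and weight scales
# (assumed form of the derived qparams; see _derive_qparams_fn)
derived_bias_scale = input_scale * weight_scale   # 6e-05

# int16 activations (16a8w): the bias is quantized like the int16 output,
# which is what returning self.output_activation above expresses
int16_bias_scale = output_scale                   # 0.05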
