
Commit bb81136

per and digantdesai authored
Arm backend: Int16 linear support (#14258)
### Summary
Adds support for a16w8 for linear when targeting a backend with +int16 extension.
Fixes #13729

### Test plan
Tested through unit tests.

Signed-off-by: Per Åstrand <[email protected]>
Co-authored-by: Digant Desai <[email protected]>
1 parent d7b9010 commit bb81136

12 files changed: +301 -41 lines
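For context, "a16w8" means int16 activations with int8 weights. The sketch below illustrates that numeric scheme for a linear layer using only plain PyTorch; the scales and helper function are illustrative assumptions, not part of this commit.

import torch

def quantize(x, scale, qmin, qmax):
    # Symmetric per-tensor quantization: round(x / scale), clamped to the target range.
    return torch.clamp(torch.round(x / scale), qmin, qmax).to(torch.int64)

x = torch.randn(4, 8)              # activations -> int16
w = torch.randn(16, 8)             # weights -> int8
x_scale = x.abs().max() / 32767.0
w_scale = w.abs().max() / 127.0

x_q = quantize(x, x_scale, -32768, 32767)   # a16
w_q = quantize(w, w_scale, -128, 127)       # w8

# Integer matmul accumulates in a wide type (int64 here; TOSA uses int48 for the conv case).
acc = x_q @ w_q.T
y = acc.to(torch.float32) * (x_scale * w_scale)   # dequantized result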

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -46,6 +46,9 @@
 from .decompose_glu_pass import DecomposeGluPass  # noqa
 from .decompose_grouped_conv import DecomposeGroupedConv  # noqa
 from .decompose_groupnorm_pass import DecomposeGroupNormPass  # noqa
+from .decompose_int16_activation_conv2d_pass import (  # noqa
+    DecomposeConv2dWithInt16ActivationPass,
+)
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
 from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass  # noqa

backends/arm/_passes/add_bias_pass.py

Lines changed: 5 additions & 0 deletions
@@ -8,6 +8,7 @@
 import torch
 from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
 from executorch.backends.transforms.utils import create_constant_placeholder

 from executorch.exir.dialects._ops import ops as exir_ops
@@ -59,6 +60,10 @@ def call(self, graph_module):
                     persistent_buffer=True,
                     name=f"{node.name}_bias",
                 )
+                if node.args[0].meta["val"].dtype == torch.int16:
+                    bias_node.meta[TosaSpecialDtype.meta_key()] = (
+                        TosaSpecialDtype.INT48
+                    )
                 node.update_arg(2, bias_node)

         if modified:

backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 1 deletion
@@ -42,6 +42,7 @@
     DecomposeAtanPass,
     DecomposeAvgPool2d,
     DecomposeBatchNormNoStatsPass,
+    DecomposeConv2dWithInt16ActivationPass,
     DecomposeCoshPass,
     DecomposeCosineSimilarityPass,
     DecomposeCumsumPass,
@@ -183,6 +184,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ComputeConstantOpsAOT(exported_program))

         self.add_pass(DecomposeGroupedConv())
+
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
@@ -196,9 +198,14 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

         self.add_pass(FuseViewCopyTransform())
         self.add_pass(FuseConstantArgsPass(exported_program))
+        self.add_pass(InsertTableOpsPass(exported_program))
+        # If we have a conv2d with int16 activation, split it up into a convolution
+        # and an addition to work around the lack of support for int48 in torch.
+        # This needs to happen before AddBiasPass, but after the table ops are
+        # inserted, so that we can validate that conv2d has the right dtype arguments.
+        self.add_pass(DecomposeConv2dWithInt16ActivationPass())
         self.add_pass(AddBiasPass(exported_program))

-        self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(RemoveNoopPass())
backends/arm/_passes/decompose_int16_activation_conv2d_pass.py (new file)

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

import torch
from executorch.backends.arm._passes.quant_args import QuantArgs

from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class DecomposeConv2dWithInt16ActivationPass(ExportPass):
    """
    This pass decomposes a convolution with input dtype int16 and bias
    into a convolution without bias followed by an addition of the bias
    since the TOSA op requires the bias to be int48 which is hard to represent
    in torch. Instead rescale the int48 output to int16 and add the bias in int16.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.convolution.default:
            return super().call_operator(op, args, kwargs, meta)

        tosa_spec = get_context_spec()
        if not tosa_spec.support_integer():
            return super().call_operator(op, args, kwargs, meta)

        # return if no bias
        if args[2] is None:
            return super().call_operator(op, args, kwargs, meta)

        if args[0].data.dtype == torch.int8:
            return super().call_operator(op, args, kwargs, meta)
        elif args[0].data.dtype == torch.int16:
            if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
                "int16"
            ):
                raise ValueError(
                    "int16 activation for convolution requires TOSA int16 extension"
                )
        else:
            raise NotImplementedError(
                "Decomposition to conv+add only implemented for activation of int16 type"
            )

        # convolution with bias and activation is int16
        # The bias is assumed to be quantized with the same quantization parameters
        # as the output of the convolution
        bias = args[2]
        assert (
            meta.data["output_qparams"][0].dtype == bias.data.dtype
        ), "Bias needs to have same type as quantized output type"
        no_bias_args = list(args)
        no_bias_args[2] = None
        # split up to convolution + bias
        convolution = super().call_operator(op, tuple(no_bias_args), kwargs, meta)

        # create a copy of the meta without the qparams, to be used with the new nodes
        new_meta = meta.copy()
        new_meta.data.pop("output_qparams", None)
        new_meta.data.pop("input_qparams", None)

        # reshape the tensor to the same rank as the convolution output to add the bias to the channels
        channel_bias = super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (bias, [1, len(bias.data), 1, 1]),
            {},
            new_meta,
        )

        output_dtype = meta.data["output_qparams"][0].dtype

        if output_dtype == torch.int16:
            # The conv will get the output int48 scaled to int32 in serialization step.
            # To be able to add the bias we need to first scale (cast?) the output to int32.
            # The resulting i32 sum will then need to be scaled back to the output dtype.

            # calculate common rescale factor from convolution output and bias quantization
            output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
            conv_output_scale = output_qparams.scale
            bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
            bias_scale = bias_qparams.scale

            common_scale = max(bias_scale, conv_output_scale)

            # calculate how we can rescale bias and conv to a common scale and maximize the output range
            bias_rescale_factor = bias_scale / common_scale
            conv_rescale_factor = conv_output_scale / common_scale

            # Either of conv output or bias now covers the full int16 range and the other one a smaller range.
            # Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range.
            # Worst case here is that both bias and conv output covers the full int16 range so we leave one bit
            # and then one for the sign bit.
            bits_left_to_shift = 14

            # update rescale factors
            bias_rescale_factor *= 1 << bits_left_to_shift
            conv_rescale_factor *= 1 << bits_left_to_shift

            conv_output = super().call_operator(
                exir_ops.backend.tosa.RESCALE.default,
                (convolution, torch.int32, conv_rescale_factor, 0, 0),
                {},
                new_meta,
            )

            bias_rescaled = super().call_operator(
                exir_ops.backend.tosa.RESCALE.default,
                (channel_bias, torch.int32, bias_rescale_factor, 0, 0),
                {},
                new_meta,
            )

            add = super().call_operator(
                exir_ops.edge.aten.add.Tensor,
                (conv_output, bias_rescaled),
                {},
                new_meta,
            )

            res_rescale = super().call_operator(
                exir_ops.backend.tosa.RESCALE.default,
                (
                    add,
                    output_dtype,
                    (common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
                    0,
                    0,
                ),
                {},
                new_meta,
            )

        else:
            raise NotImplementedError(
                f"Decomposition to conv+add only implemented for activation of int16 type, not for {output_dtype}"
            )

        return res_rescale

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 7 additions & 0 deletions
@@ -8,11 +8,13 @@
 from typing import Set, Type

 import torch
+
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_constant_placeholder_kind,
     get_param_tensor,
     is_param_node,
 )
+from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
 from executorch.backends.transforms.utils import (
     create_constant_placeholder,
     delete_constant_placeholder,
@@ -47,9 +49,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 continue
             # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
             # Ensure tensor is on CPU and contiguous
+
+            # ensure we don't merge any special case int48_t tensors with int32_t tensors
+            # since int48_t tensors need to be instantiated separately.
+            is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None)
             t_cpu = tensor.detach().cpu().contiguous()
             data_bytes = t_cpu.numpy().tobytes()
             key = (
+                is_int48,
                 str(t_cpu.dtype),
                 tuple(t_cpu.shape),
                 hashlib.sha1(data_bytes).hexdigest(),
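The effect of the extra key element can be shown with a minimal sketch: two constants with identical bytes, one of which is tagged as an int48 bias, must keep distinct fingerprints so they are not merged. The string marker below stands in for the actual TosaSpecialDtype value and is illustrative only.

import hashlib
import torch

t = torch.ones(4, dtype=torch.int32)
data = t.numpy().tobytes()

# Same dtype, shape, and bytes -- only the special-dtype marker differs.
key_int32 = (None,    str(t.dtype), tuple(t.shape), hashlib.sha1(data).hexdigest())
key_int48 = ("INT48", str(t.dtype), tuple(t.shape), hashlib.sha1(data).hexdigest())
assert key_int32 != key_int48   # the placeholders stay separate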

backends/arm/operators/op_conv2d.py

Lines changed: 43 additions & 9 deletions
@@ -21,10 +21,11 @@
 )
 from executorch.backends.arm.operators.operator_validation_utils import (
     validate_num_inputs,
+    validate_valid_dtype,
 )
-from executorch.backends.arm.tosa import TosaSpecification
 from executorch.backends.arm.tosa.mapping import TosaArg
 from executorch.backends.arm.tosa.quant_utils import build_rescale
+from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification
 from executorch.backends.arm.tosa.utils import tosa_shape


@@ -101,6 +102,32 @@ def define_node(
         input, weight, bias, stride, pad, dilation, _, _, group = inputs
         validate_num_inputs(self.target, inputs, 9)

+        valid_input_dtypes = []
+        if self.tosa_spec.support_float():
+            valid_input_dtypes.append(ts.DType.FP32)
+        if self.tosa_spec.support_integer():
+            valid_input_dtypes.append(ts.DType.INT8)
+
+        if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension(
+            "int16"
+        ):
+            valid_input_dtypes.append(ts.DType.INT16)
+            # Check constraints for int16 activations
+            if inputs[0].dtype == ts.DType.INT16:
+                validate_valid_dtype(
+                    self.target, [inputs[1]], [ts.DType.INT8], self.tosa_spec
+                )
+                validate_valid_dtype(
+                    self.target, [inputs[2]], [ts.DType.INT48], self.tosa_spec
+                )
+
+        validate_valid_dtype(
+            self.target,
+            [inputs[0]],
+            valid_input_dtypes,
+            self.tosa_spec,
+        )
+
         # Get the attributes of convolution.
         attr = ts.TosaSerializerAttribute()
         pad_attr = [val for val in pad.special for _ in (0, 1)]
@@ -125,8 +152,8 @@ def define_node(
         )

         input_zp = 0
-        if inputs[0].dtype == ts.DType.INT8:
-            # int8 input requires quantization information
+        if inputs[0].dtype in (ts.DType.INT8, ts.DType.INT16):
+            # int8 and int16 input requires quantization information
             input_qparams = get_input_qparams(node)
             input_zp = input_qparams[0].get_zp_per_tensor()

@@ -137,15 +164,22 @@ def define_node(
             weight_zp = input_qparams[1].zp  # type: ignore[assignment]

         # The output type is int32 when input type is int8.
-        conv2d_output_name = output.name
-        if output.dtype == ts.DType.INT8:
+        if inputs[0].dtype == ts.DType.INT8:
             conv2d_res = tosa_graph.addIntermediate(
                 tosa_shape(output.shape, output.dim_order), ts.DType.INT32
             )
             conv2d_output_name = conv2d_res.name
-            acc_type = (
-                inputs[0].dtype if inputs[0].dtype == ts.DType.FP32 else ts.DType.INT32
-            )
+            acc_type = ts.DType.INT32
+        elif inputs[0].dtype == ts.DType.INT16:
+            conv2d_res = tosa_graph.addIntermediate(
+                tosa_shape(output.shape, output.dim_order), ts.DType.INT48
+            )
+            conv2d_output_name = conv2d_res.name
+            acc_type = ts.DType.INT48
+        else:
+            conv2d_output_name = output.name
+            conv2d_res = output
+            acc_type = ts.DType.FP32

         tosa_graph.addConst(
             [1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
@@ -235,7 +269,7 @@ def define_node(

         # For quantized convolution, rescale the output value back to the same
         # integer value domain of the next op. Otherwise return float32 output.
-        if inputs[0].dtype == ts.DType.INT8:
+        if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
             # Get scale_factor from input, weight, and output.
             input_scale = input_qparams[0].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore [61]
             per_channel_quant = input_qparams[1].per_channel  # pyre-ignore [61]
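Restated compactly, the intermediate/accumulator dtype now follows the input dtype. The tiny sketch below summarizes that mapping; it is an illustration, not the serializer code itself.

def conv_accumulator_dtype(input_dtype: str) -> str:
    # int8 inputs accumulate in int32, int16 inputs in int48, float inputs stay in fp32.
    return {"INT8": "INT32", "INT16": "INT48", "FP32": "FP32"}[input_dtype]

assert conv_accumulator_dtype("INT16") == "INT48"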

backends/arm/process_node.py

Lines changed: 9 additions & 2 deletions
@@ -12,7 +12,7 @@
 import torch
 import torch.fx
 from executorch.backends.arm.operators.node_visitor import NodeVisitor
-from executorch.backends.arm.tosa.mapping import TosaArg
+from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
 from executorch.backends.arm.tosa.specification import TosaSpecification
 from executorch.backends.arm.tosa.utils import tosa_shape
 from torch._export.utils import (
@@ -112,10 +112,17 @@ def process_inputs_to_parameters(
     if tosa_arg.dtype == torch.float32:
         assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

+    # Handle special case for INT48 tensors
+    special_type = node.meta.get(TosaSpecialDtype.meta_key(), None)
+    if isinstance(special_type, TosaSpecialDtype):
+        tosa_dtype = special_type.get_tosa_dtype()
+    else:
+        tosa_dtype = tosa_arg.dtype
+
     parameter_values = np.transpose(parameter_values, tosa_arg.dim_order)

     tosa_graph.addConst(
-        parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name
+        parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name
     )