
Commit 4cb71a0

Arm backend: Add missing bias in pass (#11847)
Adds a pass that inserts a missing (zero) bias input for convolution nodes. Removes the handling of missing bias from the conv2d visitor.

Signed-off-by: Oscar Andersson <[email protected]>
1 parent 121714a commit 4cb71a0
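For context, a minimal sketch (toy module, shapes, and names invented here, not part of the commit) of the kind of graph this pass targets: a convolution exported without a bias reaches the Arm backend as a convolution call with no bias input.

import torch

class NoBiasConv(torch.nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False)

    def forward(self, x):
        return self.conv(x)

# After export, the convolution node has only two tensor inputs (input and
# weight). len(node.all_input_nodes) < 3 is the condition the new
# AddBiasPass checks before inserting a zero-bias placeholder.
ep = torch.export.export(NoBiasConv(), (torch.randn(1, 3, 16, 16),))
print(ep.graph_module.graph)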

File tree

4 files changed: +67 -37 lines

backends/arm/_passes/__init__.py
backends/arm/_passes/add_bias_pass.py
backends/arm/_passes/arm_pass_manager.py
backends/arm/operators/op_conv2d.py

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,9 +5,10 @@


 from . import arm_pass_utils  # noqa
+from .arm_pass import ArmPass  # noqa  # usort: skip
+from .add_bias_pass import AddBiasPass  # noqa
 from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder  # noqa
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass  # noqa
-from .arm_pass import ArmPass  # noqa
 from .broadcast_args_pass import BroadcastArgsPass  # noqa
 from .cast_int64_pass import CastInt64BuffersToInt32Pass  # noqa
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
backends/arm/_passes/add_bias_pass.py

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.transforms.utils import create_constant_placeholder

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.export.graph_signature import InputKind


class AddBiasPass(ArmPass):
    """TOSA requires convolution nodes to have a bias input.
    This pass adds a bias input to convolution nodes that do not have one.
    The bias is set to zero.
    """

    targeted_ops = (exir_ops.edge.aten.convolution.default,)

    def call(self, graph_module):
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target not in self.targeted_ops:
                continue

            if len(node.all_input_nodes) < 3:
                modified = True
                # bias is missing
                weight_node = node.all_input_nodes[1]
                output_channels = get_first_fake_tensor(weight_node).shape[0]
                # add a node containing zeros
                # if quantized, use int32, otherwise use float32
                if (
                    "output_qparams" in node.meta
                    and len(node.meta["output_qparams"]) > 0
                ):
                    bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32)
                else:
                    bias_data = torch.zeros(
                        size=(output_channels,), dtype=torch.float32
                    )

                with graph_module.graph.inserting_after(weight_node):
                    bias_node = create_constant_placeholder(
                        self.exported_program,
                        graph=graph_module.graph,
                        kind=InputKind.PARAMETER,
                        data=bias_data,
                        persistent_buffer=True,
                        name=f"{node.name}_bias",
                    )
                    node.update_arg(2, bias_node)

        if modified:
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)
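A small usage sketch, assuming an already-exported program `ep` whose graph contains a bias-free convolution; the constructor signature mirrors the AddBiasPass(exported_program) registration in the pass manager below, and everything else here is illustrative:

from executorch.backends.arm._passes import AddBiasPass

add_bias = AddBiasPass(ep)               # ep: ExportedProgram from torch.export (assumed in scope)
result = add_bias.call(ep.graph_module)  # run the pass on the graph module
print(result.modified)                   # True if a zero-bias placeholder was inserted
graph_module = result.graph_module       # graph with the new bias inputs wired into conv nodes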

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@

 # pyre-unsafe
 from executorch.backends.arm._passes import (
+    AddBiasPass,
     AnnotateChannelsLastDimOrder,
     AnnotateDecomposedMatmulPass,
     BroadcastArgsPass,

@@ -134,6 +135,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

         self.add_pass(FuseViewCopyTransform())
         self.add_pass(FuseConstantArgsPass(exported_program))
+        self.add_pass(AddBiasPass(exported_program))

         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))

@@ -194,6 +196,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

         self.add_pass(FuseViewCopyTransform())
         self.add_pass(FuseConstantArgsPass(exported_program))
+        self.add_pass(AddBiasPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
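For orientation, a hedged sketch of how these pipelines are typically reached; the ArmPassManager constructor and method names below are assumed from the surrounding backend code and are not part of this diff:

from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager

# Assumed entry point: the Arm backend's preprocess step builds a pass manager
# for the target TOSA spec and runs the matching pipeline, which after this
# change includes AddBiasPass(exported_program) as registered above.
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline(  # tosa_spec assumed in scope
    exported_program=exported_program
)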

backends/arm/operators/op_conv2d.py

Lines changed: 0 additions & 36 deletions
@@ -109,24 +109,6 @@ def define_node(
             local_bound=False,
         )

-        # Non-bias case.
-        if len(node.all_input_nodes) == 2:
-            # Create a zero bias tensor if not presented
-            out_channels = weight.shape[0]
-            bias_name = "bias" + node.name.split("default", 1)[1]
-            bias_type = output.dtype
-            if output.dtype == ts.DType.INT8:
-                # Conv is quantized to int8, but the TOSA operator has
-                # output type int32, and the bias must be the same type
-                # as the TOSA output type
-                bias_type = ts.DType.INT32
-            bias = tosa_graph.addConst(
-                [out_channels],
-                bias_type,
-                [0] * out_channels,
-                name=bias_name,
-            )
-
         # The output type is int32 when input type is int8.
         conv2d_output_name = output.name
         if output.dtype == ts.DType.INT8:

@@ -313,24 +295,6 @@ def define_node(
             name=f"{conv2d_output_name}_weight_zp",
         )

-        # Non-bias case.
-        if len(node.all_input_nodes) == 2:
-            # Create a zero bias tensor if not presented
-            out_channels = weight.shape[0]
-            bias_name = f"{conv2d_output_name}_bias"
-            bias_type = output.dtype
-            if output.dtype == ts.DType.INT8:
-                # Conv is quantized to int8, but the TOSA operator has
-                # output type int32, and the bias must be the same type
-                # as the TOSA output type
-                bias_type = ts.DType.INT32
-            bias = tosa_graph.addConst(
-                [out_channels],
-                bias_type,
-                [0] * out_channels,
-                name=bias_name,
-            )
-
         # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
         in_channels = input.shape[1]
         out_channels = weight.shape[0]
