Commit 7ee9bc2

Add pass for fusing batchnorm into conv
The pass differs from the existing fuse passes, which use get_attr nodes that are not supported by ArmBackend; instead, it updates the existing parameters in place. Also adds tests.

Change-Id: Iad6d70e632191d74d96df62b1837d37fe60e7d3a
Parent: 164a0d2
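For reference, the folding performed by torch.nn.utils.fusion.fuse_conv_bn_weights is the standard batchnorm identity: each output channel of the conv is scaled by gamma / sqrt(var + eps), and the bias is shifted accordingly. A minimal standalone sketch of that arithmetic (illustrative names, affine case with bias; not part of this diff):

import torch

def fold_bn_into_conv(conv_w, conv_b, mean, var, eps, bn_w, bn_b):
    # batchnorm(conv(x)) = scale * (conv(x) - mean) + bn_b,
    # with scale = bn_w / sqrt(var + eps), so the scale and shift
    # can be folded into the conv's own weight and bias.
    scale = bn_w / torch.sqrt(var + eps)
    fused_w = conv_w * scale.reshape(-1, 1, 1, 1)  # scale per output channel
    fused_b = (conv_b - mean) * scale + bn_b
    return fused_w, fused_b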

File tree

4 files changed: +309, -10 lines

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@
     QuantizeFullArgument,
     RetraceFoldedDtypesPass,
 )
+from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
@@ -126,6 +127,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeSoftmaxesPass())
+        self.add_pass(FuseBatchnorm2DPass(exported_program))
 
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeFullArgument())
backends/arm/_passes/fuse_batchnorm2d_pass.py (new file)

Lines changed: 128 additions & 0 deletions
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._export.utils import get_buffer, get_param
from torch.fx import Node
from torch.nn.utils.fusion import fuse_conv_bn_weights


class FuseBatchnorm2DPass(ExportPass):
    """Fuses the pattern convolution -> batchnorm by updating
    the weights and bias of the convolution and removing the batchnorm.
    """

    def __init__(self, exported_program: ExportedProgram):
        self.exported_program = exported_program
        super().__init__()

    def is_fuseable_conv_bn(self, node: Node):
        """Returns True if node is a batchnorm that can be fused into
        a parent convolution."""
        if node.op != "call_function":
            return False
        if node.target not in (
            exir_ops.edge.aten._native_batch_norm_legit,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        ):
            return False
        conv = node.all_input_nodes[0]
        if conv.target != exir_ops.edge.aten.convolution.default:
            return False
        # Batchnorm users are getitem; we can only handle those that get the first element.
        for user in node.users:
            get_index = user.args[1]
            if get_index != 0:
                return False
        # Since we change the output of the conv, fuse only if it has a single user.
        if len(conv.users) > 1:
            return False
        # For similar reasons, only fuse if the conv parameters have a single user.
        if len(conv.all_input_nodes[1].users) > 1:
            return False
        if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
            return False
        return True

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
        modified = False
        for node in graph_module.graph.nodes:
            if not self.is_fuseable_conv_bn(node):
                continue

            def get_param_or_none(arg) -> torch.nn.Parameter | None:
                """get_param, but check if arg is None first."""
                return (
                    get_param(self.exported_program, arg) if arg is not None else None
                )

            # Get weight, bias, mean, var and epsilon from the batchnorm.
            bn = node
            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
            bn_weight = get_param_or_none(bn_weight_node)
            bn_bias = get_param_or_none(bn_bias_node)

            running_mean = get_buffer(self.exported_program, bn_mean_node)
            running_var = get_buffer(self.exported_program, bn_var_node)
            if running_mean is None or running_var is None:
                raise ValueError(
                    "Parameters running_mean and running_var of batchnorm can't be None."
                )
            epsilon = bn.args[-1]

            # Get weight and bias from the conv.
            conv_weight_node, conv_bias_node = conv.args[1:3]
            conv_weight = get_param(self.exported_program, conv_weight_node)
            conv_bias = get_param_or_none(conv_bias_node)
            if conv_weight is None:
                raise ValueError("Parameter weight of convolution can't be None.")

            # Compute conv parameters folded with batchnorm.
            fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
                conv_weight,
                conv_bias,
                running_mean,
                running_var,
                epsilon,
                bn_weight,
                bn_bias,
            )

            # Set the conv parameters to the fused values.
            def try_set_param(
                param_node: Node | None, param_value: torch.nn.Parameter
            ) -> bool:
                """set_param, but check if param_node is None first. Return True
                if the param was set successfully, otherwise False."""
                if param_node is not None:
                    param_name = (
                        self.exported_program.graph_signature.inputs_to_parameters[
                            param_node.name
                        ]
                    )
                    self.exported_program.state_dict[param_name] = param_value
                    return True
                return False

            try_set_param(conv_weight_node, fused_conv_weight)
            if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
                bn_bias_node, fused_conv_bias
            ):
                # Conv didn't have a bias but batchnorm did; steal the bias from batchnorm.
                conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
                conv.args = conv_args

            # Erasing nodes is handled by dead-code elimination.
            for user in bn.users:
                user.replace_all_uses_with(conv)
            modified = True

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module=graph_module, modified=modified)
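As a quick sanity check of the fused parameters, here is a standalone sketch (hypothetical, not part of this diff) comparing conv -> batchnorm in eval mode against a conv rebuilt with fuse_conv_bn_weights:

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 3, kernel_size=3)
bn = torch.nn.BatchNorm2d(3).eval()  # eval mode: use running statistics
bn.running_mean = torch.rand(3)
bn.running_var = torch.rand(3)

fused = torch.nn.Conv2d(3, 3, kernel_size=3)
fused.weight, fused.bias = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)

x = torch.randn(1, 3, 16, 16)
# The fused conv should reproduce conv -> batchnorm up to float rounding.
torch.testing.assert_close(fused(x), bn(conv(x)), rtol=1e-4, atol=1e-5)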

backends/arm/test/ops/test_conv_combos.py

Lines changed: 21 additions & 10 deletions
@@ -16,6 +16,7 @@
 from executorch.backends.arm.test.tester.arm_tester import ArmTester
 from executorch.exir.backend.backend_details import CompileSpec
 from parameterized import parameterized
+from torch.nn.parameter import Parameter
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -112,12 +113,16 @@ class ComboConvBatchnormRelu6(torch.nn.Module):
         "executorch_exir_dialects_edge__ops_aten_hardtanh_default",
     ]
 
-    def __init__(self):
+    def __init__(self, affine: bool):
         super().__init__()
         self.conv2d = torch.nn.Conv2d(
             in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
         )
-        self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=False)
+        self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine)
+        self.batch_norm2d.running_mean = torch.rand(3)
+        self.batch_norm2d.running_var = torch.rand(3)
+        self.batch_norm2d.weight = Parameter(torch.rand(3))
+        self.batch_norm2d.bias = Parameter(torch.rand(3))
         self.relu6 = torch.nn.ReLU6()
 
     def get_inputs(self) -> Tuple[torch.Tensor]:
@@ -289,24 +294,30 @@ def test_conv_meandim_u85_BI(self):
     ##############################
     ## Conv + batch norm + relu ##
     ##############################
-    def test_conv_batchnorm_relu6_tosa_MI(self):
-        model = ComboConvBatchnormRelu6()
+    affine_params = [("affine", True), ("_no_affine", False)]
+
+    @parameterized.expand(affine_params)
+    def test_conv_batchnorm_relu6_tosa_MI(self, test_suffix, affine):
+        model = ComboConvBatchnormRelu6(affine)
         self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
 
-    def test_conv_batchnorm_relu6_tosa_BI(self):
-        model = ComboConvBatchnormRelu6()
+    @parameterized.expand(affine_params)
+    def test_conv_batchnorm_relu6_tosa_BI(self, test_suffix, affine):
+        model = ComboConvBatchnormRelu6(affine)
         self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())
 
+    @parameterized.expand(affine_params)
     @pytest.mark.corstone_fvp
-    def test_conv_batchnorm_relu6_u55_BI(self):
-        model = ComboConvBatchnormRelu6()
+    def test_conv_batchnorm_relu6_u55_BI(self, test_suffix, affine):
+        model = ComboConvBatchnormRelu6(affine)
         self._test_conv_combo_ethos_BI_pipeline(
             model, common.get_u55_compile_spec(), model.get_inputs()
         )
 
+    @parameterized.expand(affine_params)
     @pytest.mark.corstone_fvp
-    def test_conv_batchnorm_relu_u85_BI(self):
-        model = ComboConvBatchnormRelu6()
+    def test_conv_batchnorm_relu_u85_BI(self, test_suffix, affine):
+        model = ComboConvBatchnormRelu6(affine)
         self._test_conv_combo_ethos_BI_pipeline(
             model,
             common.get_u85_compile_spec(),
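In the expansion above, the first element of each affine_params tuple only contributes to the generated test names. A tiny standalone illustration of how parameterized.expand names such tests (hypothetical, not part of this diff):

import unittest
from parameterized import parameterized

class Demo(unittest.TestCase):
    @parameterized.expand([("affine", True), ("_no_affine", False)])
    def test_bn(self, test_suffix, affine):
        # Generates tests named roughly test_bn_0_affine and test_bn_1__no_affine.
        self.assertIsInstance(affine, bool)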
New test file for FuseBatchnorm2DPass

Lines changed: 158 additions & 0 deletions
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch
from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses
from parameterized import parameterized


class MergeOneOfTwoBN(torch.nn.Module):
    ops_before_pass = {
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
        "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
    }
    ops_after_pass = {
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1,
        "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
    }

    def __init__(self, affine: bool):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
        )
        self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine)
        self.batch_norm2d.running_mean = torch.rand(3)
        self.batch_norm2d.running_var = torch.rand(3)
        if affine:
            self.batch_norm2d.weight = torch.nn.Parameter(torch.rand(3))
            self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3))
        self.relu6 = torch.nn.ReLU6()

    def get_inputs(self) -> tuple[torch.Tensor]:
        return (torch.randn(1, 3, 256, 256),)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.batch_norm2d(x)
        x = self.relu6(x)
        x = self.batch_norm2d(x)
        return x


class MergeTwosOfTwoBN(torch.nn.Module):
    ops_before_pass = {
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
        "executorch_exir_dialects_edge__ops_aten_convolution_default": 2,
    }
    ops_after_pass = {
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 0,
        "executorch_exir_dialects_edge__ops_aten_convolution_default": 2,
    }

    def __init__(self, affine: bool):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
        )
        self.conv2d2 = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
        )
        self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine)
        self.batch_norm2d.running_mean = torch.rand(3)
        self.batch_norm2d.running_var = torch.rand(3)
        if affine:
            self.batch_norm2d.weight = torch.nn.Parameter(torch.rand(3))
            self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3))
        self.relu6 = torch.nn.ReLU6()

    def get_inputs(self) -> tuple[torch.Tensor]:
        return (torch.randn(1, 3, 256, 256),)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.batch_norm2d(x)
        x = self.relu6(x)
        x = self.conv2d2(x)
        x = self.batch_norm2d(x)
        return x


class MergeNoBN(torch.nn.Module):
    ops_before_pass = {
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
        "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
    }
    ops_after_pass = {
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
        "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
    }

    def __init__(self, affine: bool):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
        )
        self.conv2d2 = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
        )
        self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine)
        self.batch_norm2d.running_mean = torch.rand(3)
        self.batch_norm2d.running_var = torch.rand(3)
        if affine:
            self.batch_norm2d.weight = torch.nn.Parameter(torch.rand(3))
            self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3))
        self.relu6 = torch.nn.ReLU6()

    def get_inputs(self) -> tuple[torch.Tensor]:
        return (torch.randn(1, 3, 256, 256),)

    def forward(self, x):
        x1 = self.conv2d(x)
        x = self.batch_norm2d(x1)  # Can't be fused since x1 has multiple users.
        x = self.relu6(x)
        y = self.conv2d2(x1)
        z = self.conv2d2(x)
        # Can't be fused since the parameters of conv2d2 have multiple users.
        a = self.batch_norm2d(y)

        return z, a


modules = [
    MergeOneOfTwoBN(True),
    MergeOneOfTwoBN(False),
    MergeTwosOfTwoBN(True),
    MergeNoBN(True),
]


class TestFuseBatchnormPass(unittest.TestCase):

    @parameterized.expand(modules)
    def test_fuse_batchnorm_tosa_MI(self, module):
        """Test various cases where the batchnorm should and shouldn't be fused."""
        inputs = module.get_inputs()
        test_pass_stage = RunPasses(passes_with_exported_program=[FuseBatchnorm2DPass])
        (
            ArmTester(
                module,
                example_inputs=inputs,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
            )
            .export()
            .to_edge()
            .check_count(module.ops_before_pass)
            .run_passes(test_pass_stage)
            .check_count(module.ops_after_pass)
            .run_method_and_compare_outputs()
        )
