Skip to content

Commit 0966d99

Browse files
Arm backend: Move batchnorm decomposition to pass (#8050)
Move batchnorm decomposition to pass. The decomposition logic of batchnorm is better suited for a pass than a node visitor. Signed-off-by: Oscar Andersson <[email protected]>
1 parent c9f5f19 commit 0966d99

File tree

6 files changed

+146
-254
lines changed

6 files changed

+146
-254
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
from executorch.backends.arm._passes.convert_squeezes_to_view import ( # type: ignore[import-not-found]
2525
ConvertSqueezesToViewPass,
2626
)
27+
from executorch.backends.arm._passes.decompose_batchnorm_pass import (
28+
DecomposeBatchNormPass,
29+
)
2730
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
2831
from executorch.backends.arm._passes.decompose_layernorm_pass import (
2932
DecomposeLayerNormPass,
@@ -87,6 +90,7 @@ def _transform(self, graph_module: GraphModule):
8790
def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
8891
self.add_pass(FuseQuantizedActivationPass())
8992
self.add_pass(RemoveGetItemPass())
93+
self.add_pass(DecomposeBatchNormPass())
9094
self.add_pass(ConvertSplitToSlicePass())
9195
self.add_pass(ConvertMmToBmmPass())
9296
self.add_pass(DecomposeLinearPass())
@@ -121,6 +125,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
121125
self.add_pass(ConvertSplitToSlicePass())
122126
self.add_pass(ConvertMmToBmmPass())
123127
self.add_pass(DecomposeLinearPass())
128+
self.add_pass(DecomposeBatchNormPass())
124129
self.add_pass(DecomposeLayerNormPass())
125130
self.add_pass(DecomposeVarPass())
126131
self.add_pass(DecomposeMeanDimPass())
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import operator
10+
11+
import torch
12+
from executorch.backends.arm._passes.arm_pass_utils import create_node
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import ExportPass, PassResult
15+
16+
17+
# Edge-dialect batchnorm targets recognised by get_bn_decomposition() and
# matched against node.target in DecomposeBatchNormPass.call().
edge_bn_ops = (exir_ops.edge.aten._native_batch_norm_legit_no_training.default,)
18+
19+
20+
def get_bn_decomposition(op) -> tuple:
    """
    Look up the edge ops used to express a batchnorm op.

    The returned tuple is ordered as
    (sub, add, rsqrt, mul, view_copy, full).

    Raises RuntimeError when `op` is not one of `edge_bn_ops`.
    """
    # Guard clause: reject anything that is not a known batchnorm edge op.
    if op not in edge_bn_ops:
        raise RuntimeError(f"Can't get decomposition for {op}")

    edge_aten = exir_ops.edge.aten
    return (
        edge_aten.sub.Tensor,
        edge_aten.add.Tensor,
        edge_aten.rsqrt.default,
        edge_aten.mul.Tensor,
        edge_aten.view_copy.default,
        edge_aten.full.default,
    )
36+
37+
38+
class DecomposeBatchNormPass(ExportPass):
    """
    Decompose BatchNorm to:
    %output = (%x - %E[x]) / SQRT( %Var[x] + %epsilon ) * %gamma + %beta
    e.g.
    %output = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const ) * %weights + %bias
    ->
    %op1 = sub(%activations, %running_mean)
    %op2 = add(%running_var, %epsilon_const)
    %op3 = rsqrt(%op2)
    %op4 = mul(%op1, %op3)
    %op5 = mul(%op4, %weights)
    %output = add(%op5, %bias)

    Before use, running_mean, running_var, weights and bias are each passed
    through the view op to a shape that is 1 everywhere except dim 1, which
    takes running_mean's first dimension. epsilon is materialised with the
    full op as an all-ones-shaped tensor of the output rank.
    The multiply by weights and the add of bias are skipped when the
    respective argument is None.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Replace every batchnorm node in `graph_module` with its decomposition.

        Returns a PassResult holding the (possibly re-traced) graph module and
        a flag telling whether any node was rewritten.

        Raises RuntimeError if a batchnorm node has momentum != 0.1.
        """
        modified = False
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function" or node.target not in edge_bn_ops:
                continue

            (
                activations,
                weights,
                bias,
                running_mean,
                running_var,
                momentum,
                epsilon,
            ) = node.args
            if momentum != 0.1:
                raise RuntimeError(f"Expected momentum=0.1 but got {momentum}")

            # The op's meta value is normally a tuple whose first entry is the
            # normalised output. Accept a bare value as well so that it is
            # never indexed along its first dimension by mistake.
            val = node.meta["val"]
            out_val = val[0] if isinstance(val, (tuple, list)) else val
            dtype = out_val.dtype
            rank = len(out_val.size())

            # Broadcastable shape for the per-channel operands: [1, C, 1, ...].
            channel_shape = [1] * rank
            channel_shape[1] = running_mean.meta["val"].shape[0]
            epsilon_shape = [1] * rank

            sub, add, rsqrt, mul, view, full = get_bn_decomposition(node.target)
            with graph.inserting_before(node):
                mean_reshaped = create_node(
                    graph,
                    view,
                    args=(running_mean, channel_shape),
                )
                op1 = create_node(graph, sub, args=(activations, mean_reshaped))
                # Kept under its own name so the `full` op target is not
                # shadowed by the node it creates.
                epsilon_full = create_node(
                    graph,
                    full,
                    args=(epsilon_shape, epsilon),
                    kwargs={"dtype": dtype},
                )
                var_reshaped = create_node(
                    graph,
                    view,
                    args=(running_var, channel_shape),
                )
                op2 = create_node(graph, add, args=(var_reshaped, epsilon_full))
                op3 = create_node(graph, rsqrt, args=(op2,))
                op4 = create_node(graph, mul, args=(op1, op3))

                output = op4
                if weights is not None:
                    weights_reshaped = create_node(
                        graph,
                        view,
                        args=(weights, channel_shape),
                    )
                    output = create_node(
                        graph, mul, args=(output, weights_reshaped)
                    )
                if bias is not None:
                    bias_reshaped = create_node(
                        graph, view, args=(bias, channel_shape)
                    )
                    output = create_node(graph, add, args=(output, bias_reshaped))

            # Snapshot the users before rewiring, then point the consumers of
            # any getitem(node, i) straight at the decomposed output.
            users = [user for user in node.users if node != user]
            node.replace_all_uses_with(output)
            for user in users:
                if user.target == operator.getitem:
                    user.replace_all_uses_with(output)
            graph.erase_node(node)
            # Drops the now-unused getitem nodes.
            graph.eliminate_dead_code()
            modified = True

        if modified:
            graph_module.recompile()
            # Re-run the base ExportPass over the rewritten graph.
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -82,9 +82,10 @@ def call(self, graph_module: torch.fx.GraphModule):
8282
n_dims = len(normalized_shape)
8383
if isinstance(meta["val"], tuple):
8484
shape = meta["val"][0].size()
85+
dtype = meta["val"][0].dtype
8586
else:
8687
shape = meta["val"].size()
87-
dtype = meta["val"][0].dtype
88+
dtype = meta["val"].dtype
8889
rank = len(shape)
8990
dims = list(range(-1, -1 * (n_dims + 1), -1))
9091
dims = [dim % rank for dim in dims]

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
node_visitor,
1010
op_add,
1111
op_avg_pool2d,
12-
op_batch_norm,
1312
op_bmm,
1413
op_cat,
1514
op_clamp,

0 commit comments

Comments
 (0)