Skip to content

Commit a9bce9b

Browse files
wwwindhinriksnaer
authored andcommitted
Arm backend: Add decomposition for BatchNorm2D no stats (pytorch#11970)
Signed-off-by: Elena Zhelezina <[email protected]>
1 parent 0cd1b2c commit a9bce9b

File tree

4 files changed

+266
-11
lines changed

4 files changed

+266
-11
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2424
from .convert_to_clamp import ConvertToClampPass # noqa
2525
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
26+
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa
2627
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2728
from .decompose_div_pass import DecomposeDivPass # noqa
2829
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ConvertSqueezesToViewPass,
2727
ConvertToClampPass,
2828
DecomposeAvgPool2d,
29+
DecomposeBatchNormNoStatsPass,
2930
DecomposeCosineSimilarityPass,
3031
DecomposeDivPass,
3132
DecomposeEmbeddingPass,
@@ -164,6 +165,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
164165
self.add_pass(DecomposeLeakyReLUPass())
165166
self.add_pass(DecomposeGroupNormPass())
166167
self.add_pass(DecomposeLayerNormPass())
168+
self.add_pass(DecomposeBatchNormNoStatsPass())
167169
self.add_pass(DecomposeVarPass())
168170
self.add_pass(
169171
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import operator
9+
10+
import torch
11+
from executorch.backends.arm._passes import ArmPass
12+
from executorch.backends.arm._passes.arm_pass_utils import create_node
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import PassResult
15+
16+
17+
class DecomposeBatchNormNoStatsPass(ArmPass):
18+
"""
19+
Decompose BatchNorm2d(track_running_stats=False) (aten._native_batch_norm_legit_no_training)
20+
into a sequence of elementwise operations:
21+
22+
# let input = x, rm = running_mean, rv = running_var, eps: float
23+
rm_view = view(rm, weights_shape)
24+
rv_view = view(rv, weights_shape)
25+
centered = sub(x, rm_view)
26+
eps_full = full(eps_shape, eps)
27+
var_eps = add(rv_view, eps_full)
28+
inv_sqrt = rsqrt(var_eps)
29+
normed = mul(centered, inv_sqrt)
30+
weighted = mul(normed, w_view)
31+
biased = add(weighted, b_view)
32+
33+
Source: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
34+
"""
35+
36+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
37+
bn_ops = (
38+
exir_ops.edge.aten._native_batch_norm_legit.no_stats,
39+
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
40+
torch.ops.aten._native_batch_norm_legit_no_training.default,
41+
torch.ops.aten.batch_norm.default,
42+
torch.ops.aten.native_batch_norm.default,
43+
)
44+
45+
for node in graph_module.graph.nodes:
46+
if node.op != "call_function" or node.target not in bn_ops:
47+
continue
48+
49+
if node.target in (
50+
torch.ops.aten.batch_norm.default,
51+
torch.ops.aten.native_batch_norm.default,
52+
):
53+
# signature: (input, weight, bias, mean, var, training, momentum, eps, cudnn_enabled)
54+
# pos‐arg 5 is training
55+
training = node.kwargs.get("training", False)
56+
if len(node.args) > 5:
57+
training = node.args[5]
58+
if training:
59+
# skip training‐mode batchnorm
60+
continue
61+
62+
# Extract args
63+
args = node.args
64+
meta = node.meta
65+
66+
# Default eps
67+
eps: float = torch.finfo().eps
68+
# weight and bias may be None
69+
x = args[0]
70+
weight = args[1] if len(args) > 1 else None
71+
bias = args[2] if len(args) > 2 else None
72+
running_mean = args[3]
73+
running_var = args[4]
74+
if len(args) > 6:
75+
eps = args[6]
76+
77+
# Determine shapes
78+
val = meta.get("val")
79+
ref_tensor = val[0] if isinstance(val, tuple) else val
80+
shape = tuple(ref_tensor.size())
81+
dtype = ref_tensor.dtype
82+
rank = len(shape)
83+
84+
# channel dimension is 1 for BatchNorm2d
85+
channel_axis = 1
86+
weights_shape = [1] * rank
87+
weights_shape[channel_axis] = shape[channel_axis]
88+
num_features = shape[channel_axis]
89+
90+
# Ops to use
91+
sub_op = exir_ops.edge.aten.sub.Tensor
92+
view_op = exir_ops.edge.aten.view_copy.default
93+
full_op = exir_ops.edge.aten.full.default
94+
add_op = exir_ops.edge.aten.add.Tensor
95+
rsqrt_op = exir_ops.edge.aten.rsqrt.default
96+
mul_op = exir_ops.edge.aten.mul.Tensor
97+
98+
# Begin decomposition
99+
with graph_module.graph.inserting_before(node):
100+
# reshape running stats
101+
rm_view = create_node(
102+
graph_module.graph,
103+
view_op,
104+
args=(running_mean, weights_shape),
105+
from_node=node,
106+
)
107+
rv_view = create_node(
108+
graph_module.graph,
109+
view_op,
110+
args=(running_var, weights_shape),
111+
from_node=node,
112+
)
113+
# center input
114+
centered = create_node(
115+
graph_module.graph,
116+
sub_op,
117+
args=(x, rm_view),
118+
from_node=node,
119+
)
120+
# epsilon tensor
121+
eps_shape = [1] * rank
122+
eps_full = create_node(
123+
graph_module.graph,
124+
full_op,
125+
args=(eps_shape, eps),
126+
kwargs={"dtype": dtype},
127+
from_node=node,
128+
)
129+
# var + eps
130+
var_eps = create_node(
131+
graph_module.graph,
132+
add_op,
133+
args=(rv_view, eps_full),
134+
from_node=node,
135+
)
136+
# inverse sqrt
137+
inv_sqrt = create_node(
138+
graph_module.graph,
139+
rsqrt_op,
140+
args=(var_eps,),
141+
from_node=node,
142+
)
143+
# normalized
144+
normed = create_node(
145+
graph_module.graph,
146+
mul_op,
147+
args=(centered, inv_sqrt),
148+
from_node=node,
149+
)
150+
151+
# weight
152+
if weight is None:
153+
one = create_node(
154+
graph_module.graph,
155+
full_op,
156+
args=([num_features], 1),
157+
kwargs={"dtype": dtype},
158+
from_node=node,
159+
)
160+
w_view = create_node(
161+
graph_module.graph,
162+
view_op,
163+
args=(one, weights_shape),
164+
from_node=node,
165+
)
166+
else:
167+
w_view = create_node(
168+
graph_module.graph,
169+
view_op,
170+
args=(weight, weights_shape),
171+
from_node=node,
172+
)
173+
weighted = create_node(
174+
graph_module.graph,
175+
mul_op,
176+
args=(normed, w_view),
177+
from_node=node,
178+
)
179+
180+
# bias
181+
if bias is None:
182+
zero = create_node(
183+
graph_module.graph,
184+
full_op,
185+
args=([num_features], 0),
186+
kwargs={"dtype": dtype},
187+
from_node=node,
188+
)
189+
b_view = create_node(
190+
graph_module.graph,
191+
view_op,
192+
args=(zero, weights_shape),
193+
from_node=node,
194+
)
195+
else:
196+
b_view = create_node(
197+
graph_module.graph,
198+
view_op,
199+
args=(bias, weights_shape),
200+
from_node=node,
201+
)
202+
final_out = create_node(
203+
graph_module.graph,
204+
add_op,
205+
args=(weighted, b_view),
206+
from_node=node,
207+
)
208+
209+
users = [u for u in node.users if u is not node]
210+
node.replace_all_uses_with(final_out)
211+
for u in users:
212+
if u.target == operator.getitem:
213+
u.replace_all_uses_with(final_out)
214+
graph_module.graph.erase_node(node)
215+
graph_module.graph.eliminate_dead_code()
216+
217+
graph_module.recompile()
218+
new_gm = super().call(graph_module).graph_module
219+
return PassResult(new_gm, True)

backends/arm/test/ops/test_batch_norm.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ class BatchNorm2dNoStats(torch.nn.Module):
224224
Decomposes into _native_batch_norm_legit.no_stats
225225
"""
226226

227+
aten_ops = ["torch.ops.aten.batch_norm.default"]
228+
227229
def __init__(
228230
self,
229231
num_features: int,
@@ -250,29 +252,60 @@ def forward(self, x):
250252
return self.batch_norm_2d(x)
251253

252254

253-
@pytest.mark.skip(
254-
reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
255-
)
256-
def test_native_batch_norm_legit_no_stats_tosa_MI():
257-
pass
255+
@common.parametrize("test_data", test_data_suite)
256+
def test_native_batch_norm_legit_no_stats_tosa_MI(test_data: Tuple):
257+
test_data, model_params = test_data()
258+
pipeline = TosaPipelineMI[input_t1](
259+
BatchNorm2dNoStats(*model_params),
260+
(test_data,),
261+
aten_op=BatchNorm2dNoStats.aten_ops,
262+
)
263+
pipeline.run()
258264

259265

260266
@pytest.mark.skip(
261267
reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
262268
)
263-
def test_native_batch_norm_legit_no_stats_tosa_BI():
264-
pass
269+
def test_native_batch_norm_legit_no_stats_tosa_BI(test_data: Tuple):
270+
test_data, model_params = test_data()
271+
pipeline = TosaPipelineBI[input_t1](
272+
BatchNorm2dNoStats(*model_params),
273+
(test_data,),
274+
aten_op=BatchNorm2dNoStats.aten_ops,
275+
qtol=1,
276+
)
277+
pipeline.run()
265278

266279

267280
@pytest.mark.skip(
268281
reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
269282
)
270-
def test_native_batch_norm_legit_no_stats_u55_BI():
271-
pass
283+
@common.parametrize("test_data", test_data_suite)
284+
@common.XfailIfNoCorstone300
285+
def test_native_batch_norm_legit_no_stats_u55_BI(test_data: Tuple):
286+
test_data, model_params = test_data()
287+
pipeline = EthosU55PipelineBI[input_t1](
288+
BatchNorm2dNoStats(*model_params),
289+
(test_data,),
290+
aten_op=BatchNorm2dNoStats.aten_ops,
291+
run_on_fvp=True,
292+
qtol=1,
293+
)
294+
pipeline.run()
272295

273296

274297
@pytest.mark.skip(
275298
reason="MLETORCH-999: Add support for _native_batch_norm_legit.no_stats."
276299
)
277-
def test_native_batch_norm_legit_no_stats_u85_BI():
278-
pass
300+
@common.parametrize("test_data", test_data_suite)
301+
@common.XfailIfNoCorstone320
302+
def test_native_batch_norm_legit_no_stats_u85_BI(test_data: Tuple):
303+
test_data, model_params = test_data()
304+
pipeline = EthosU85PipelineBI[input_t1](
305+
BatchNorm2dNoStats(*model_params),
306+
(test_data,),
307+
aten_op=BatchNorm2dNoStats.aten_ops,
308+
run_on_fvp=False,
309+
qtol=1,
310+
)
311+
pipeline.run()

0 commit comments

Comments
 (0)