Skip to content

Commit 3a62512

Browse files
committed
Add standalone batch norm support via depthwise conv conversion.
1 parent daebcde commit 3a62512

File tree

6 files changed

+523
-16
lines changed

6 files changed

+523
-16
lines changed

backends/xnnpack/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from executorch.backends.xnnpack._passes.conv1d_unsqueeze_pass import (
1515
Conv1dUnsqueezePass,
1616
)
17+
from executorch.backends.xnnpack._passes.convert_batch_norm_to_depthwise_conv import (
18+
ConvertBatchNormToDepthwiseConvPass,
19+
)
1720
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
1821
from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
1922
from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
@@ -64,6 +67,7 @@ def __init__(
6467
ConvertToSDPAPass,
6568
ConstPropPass,
6669
FuseBatchNormWithConvPass,
70+
ConvertBatchNormToDepthwiseConvPass,
6771
FuseActivationPass,
6872
DecomposeConcatenate,
6973
RemoveGetItemPass,
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import operator
8+
from typing import Optional
9+
10+
import torch
11+
from executorch.backends.transforms.utils import (
12+
create_constant_placeholder,
13+
delete_constant_placeholder,
14+
)
15+
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
16+
from executorch.backends.xnnpack.utils.utils import (
17+
get_param_tensor,
18+
get_tensor_name,
19+
is_param_node,
20+
)
21+
from executorch.exir import ExportedProgram
22+
from executorch.exir.dialects._ops import ops as exir_ops
23+
from executorch.exir.pass_base import PassResult
24+
from torch.export.graph_signature import InputKind
25+
26+
27+
class ConvertBatchNormToDepthwiseConvPass(XNNPACKPass):
28+
"""
29+
Converts standalone batch norm operations to depthwise convolutions.
30+
This allows XNNPACK to handle batch norm operations that cannot be fused
31+
with preceding convolutions.
32+
33+
BatchNorm formula: y = (x - mean) / sqrt(var + eps) * weight + bias
34+
This can be represented as a 1x1 depthwise convolution with:
35+
- conv_weight = weight / sqrt(var + eps)
36+
- conv_bias = bias - mean * weight / sqrt(var + eps)
37+
"""
38+
39+
def call(self, graph_module: torch.fx.GraphModule):
40+
graph = graph_module.graph
41+
constant_placeholders_to_delete = set()
42+
nodes_to_convert = []
43+
44+
# First pass: identify standalone batch norm nodes
45+
for node in graph.nodes:
46+
if (
47+
node.target != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
48+
and node.target != exir_ops.edge.aten.native_batch_norm.default
49+
):
50+
continue
51+
52+
# Check if this batch norm can be fused with a preceding conv
53+
# If so, skip it - the fusion pass will handle it
54+
if self._can_be_fused_with_conv(node):
55+
continue
56+
57+
# Check if this is a valid standalone batch norm to convert
58+
if self._can_convert_to_depthwise_conv(node):
59+
nodes_to_convert.append(node)
60+
61+
# Second pass: convert the identified nodes
62+
for bn_node in nodes_to_convert:
63+
conv_node = self._convert_batch_norm_to_depthwise_conv(
64+
graph_module, bn_node, constant_placeholders_to_delete
65+
)
66+
if conv_node is not None:
67+
# Replace all uses of batch norm getitem(0) with the conv node
68+
for user in list(bn_node.users):
69+
if user.target == operator.getitem and user.args[1] == 0:
70+
user.replace_all_uses_with(conv_node)
71+
graph.erase_node(user)
72+
73+
# Remove the batch norm node
74+
graph.erase_node(bn_node)
75+
76+
# Clean up unused constant placeholders
77+
if constant_placeholders_to_delete:
78+
graph_module.graph.eliminate_dead_code()
79+
for node in constant_placeholders_to_delete:
80+
if node is not None and len(node.users) == 0:
81+
delete_constant_placeholder(self.exported_program, node)
82+
83+
graph_module.recompile()
84+
# Regenerate metadata and shape information
85+
graph_module = super().call(graph_module).graph_module
86+
87+
return PassResult(graph_module, True)
88+
89+
def _can_be_fused_with_conv(self, bn_node: torch.fx.Node) -> bool:
90+
"""Check if this batch norm can be fused with a preceding convolution."""
91+
# Import here to avoid circular dependency
92+
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
93+
FuseBatchNormWithConvPass,
94+
)
95+
96+
input_node = bn_node.all_input_nodes[0]
97+
98+
# Check if input is a conv with single user (this batch norm)
99+
if (
100+
input_node.target == exir_ops.edge.aten.convolution.default
101+
and len(input_node.users) == 1
102+
):
103+
return FuseBatchNormWithConvPass.can_fuse(
104+
input_node, bn_node, self.exported_program
105+
)
106+
107+
return False
108+
109+
def _can_convert_to_depthwise_conv(self, bn_node: torch.fx.Node) -> bool:
110+
"""Check if this batch norm can be converted to depthwise conv."""
111+
112+
# All users must be getitem ops accessing the first element (output tensor)
113+
for user in bn_node.users:
114+
if user.target != operator.getitem or user.args[1] != 0:
115+
return False
116+
117+
# Check that we have the required parameters
118+
if len(bn_node.args) < 5:
119+
return False
120+
121+
# Weight, bias, running_mean, running_var must be parameters
122+
param_nodes = bn_node.args[1:5] # weight, bias, running_mean, running_var
123+
124+
for param_node in param_nodes:
125+
if not isinstance(param_node, torch.fx.Node):
126+
return False
127+
if not is_param_node(self.exported_program, param_node):
128+
return False
129+
130+
return True
131+
132+
def _convert_batch_norm_to_depthwise_conv(
133+
self,
134+
graph_module: torch.fx.GraphModule,
135+
bn_node: torch.fx.Node,
136+
constant_placeholders_to_delete: set,
137+
) -> Optional[torch.fx.Node]:
138+
"""Convert a batch norm node to a depthwise convolution."""
139+
140+
# Extract batch norm parameters
141+
input_tensor = bn_node.args[0]
142+
143+
# Cast args to Node types for parameter access
144+
bn_weight_node = bn_node.args[1] if isinstance(bn_node.args[1], torch.fx.Node) else None
145+
bn_bias_node = bn_node.args[2] if isinstance(bn_node.args[2], torch.fx.Node) else None
146+
running_mean_node = bn_node.args[3] if isinstance(bn_node.args[3], torch.fx.Node) else None
147+
running_var_node = bn_node.args[4] if isinstance(bn_node.args[4], torch.fx.Node) else None
148+
149+
if any(node is None for node in [bn_weight_node, bn_bias_node, running_mean_node, running_var_node]):
150+
return None
151+
152+
# These are guaranteed to be non-None now
153+
assert bn_weight_node is not None
154+
assert bn_bias_node is not None
155+
assert running_mean_node is not None
156+
assert running_var_node is not None
157+
158+
bn_weight = get_param_tensor(self.exported_program, bn_weight_node)
159+
bn_bias = get_param_tensor(self.exported_program, bn_bias_node)
160+
running_mean = get_param_tensor(self.exported_program, running_mean_node)
161+
running_var = get_param_tensor(self.exported_program, running_var_node)
162+
163+
# Get epsilon value
164+
if str(bn_node.target).endswith("native_batch_norm.default"):
165+
eps = bn_node.args[7] if len(bn_node.args) > 7 else 1e-5
166+
else: # _native_batch_norm_legit_no_training
167+
eps = bn_node.args[6] if len(bn_node.args) > 6 else 1e-5
168+
169+
# Ensure eps is a float
170+
if not isinstance(eps, (int, float)):
171+
eps = 1e-5
172+
173+
if any(param is None for param in [bn_weight, bn_bias, running_mean, running_var]):
174+
return None
175+
176+
# Ensure all parameters are tensors
177+
assert isinstance(bn_weight, torch.Tensor)
178+
assert isinstance(bn_bias, torch.Tensor)
179+
assert isinstance(running_mean, torch.Tensor)
180+
assert isinstance(running_var, torch.Tensor)
181+
182+
# Calculate depthwise conv parameters
183+
# BatchNorm: y = (x - mean) / sqrt(var + eps) * weight + bias
184+
# Depthwise Conv: y = x * conv_weight + conv_bias
185+
# Therefore: conv_weight = weight / sqrt(var + eps)
186+
# conv_bias = bias - mean * weight / sqrt(var + eps)
187+
188+
inv_std = torch.rsqrt(running_var + eps)
189+
conv_weight_1d = bn_weight * inv_std
190+
conv_bias_1d = bn_bias - running_mean * conv_weight_1d
191+
192+
# Reshape for depthwise conv: [C] -> [C, 1, 1, 1] for 2D conv
193+
# Assuming 4D input tensor [N, C, H, W]
194+
num_channels = conv_weight_1d.shape[0]
195+
conv_weight = conv_weight_1d.view(num_channels, 1, 1, 1)
196+
conv_bias = conv_bias_1d
197+
198+
# Create parameter names
199+
bn_weight_name = get_tensor_name(self.exported_program, bn_weight_node)
200+
conv_weight_name = (bn_weight_name + "_as_depthwise_conv_weight").replace(".", "_")
201+
conv_bias_name = (bn_weight_name + "_as_depthwise_conv_bias").replace(".", "_")
202+
203+
# Create new parameter nodes
204+
graph = graph_module.graph
205+
with graph.inserting_before(bn_node):
206+
conv_weight_node = create_constant_placeholder(
207+
exp_program=self.exported_program,
208+
graph=graph,
209+
kind=InputKind.PARAMETER,
210+
name=conv_weight_name,
211+
data=conv_weight,
212+
)
213+
214+
conv_bias_node = create_constant_placeholder(
215+
exp_program=self.exported_program,
216+
graph=graph,
217+
kind=InputKind.PARAMETER,
218+
name=conv_bias_name,
219+
data=conv_bias,
220+
)
221+
222+
# Create depthwise convolution node
223+
# Args: input, weight, bias, stride, padding, dilation, transposed, output_padding, groups
224+
conv_args = (
225+
input_tensor, # input
226+
conv_weight_node, # weight
227+
conv_bias_node, # bias
228+
[1, 1], # stride
229+
[0, 0], # padding
230+
[1, 1], # dilation
231+
False, # transposed
232+
[0, 0], # output_padding
233+
num_channels, # groups (depthwise = groups = in_channels)
234+
)
235+
236+
conv_node = graph.create_node(
237+
"call_function",
238+
exir_ops.edge.aten.convolution.default,
239+
args=conv_args,
240+
)
241+
242+
# Mark old parameters for deletion
243+
constant_placeholders_to_delete.update(bn_node.args[1:5])
244+
245+
return conv_node
246+
247+
@staticmethod
248+
def can_convert_standalone_batch_norm(
249+
bn_node: torch.fx.Node, program: ExportedProgram
250+
) -> bool:
251+
"""
252+
Static method to check if a standalone batch norm can be converted.
253+
Used by the partitioner configuration.
254+
"""
255+
# All users must be getitem ops accessing the first element
256+
for user in bn_node.users:
257+
if user.target != operator.getitem or user.args[1] != 0:
258+
return False
259+
260+
# Check that we have required parameters
261+
if len(bn_node.args) < 5:
262+
return False
263+
264+
# Weight, bias, running_mean, running_var must be parameters
265+
param_nodes = bn_node.args[1:5]
266+
267+
for param_node in param_nodes:
268+
if not isinstance(param_node, torch.fx.Node):
269+
return False
270+
if not is_param_node(program, param_node):
271+
return False
272+
273+
return True

backends/xnnpack/partition/config/node_configs.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from typing import List, Optional
1010

1111
import torch
12+
from executorch.backends.xnnpack._passes.convert_batch_norm_to_depthwise_conv import (
13+
ConvertBatchNormToDepthwiseConvPass,
14+
)
1215
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
1316
FuseBatchNormWithConvPass,
1417
)
@@ -35,20 +38,20 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
3538
return False
3639

3740
bn = node
38-
conv = node.all_input_nodes[0]
39-
40-
if conv.op != "call_function":
41-
return False
42-
43-
conv_name = format_target_name(conv.target.__name__) # pyre-ignore
44-
45-
if conv_name not in ["convolution.default"]:
46-
why(node, f"Invalid conv target {conv_name}")
47-
return False
41+
input_node = node.all_input_nodes[0]
4842

49-
can_fuse = FuseBatchNormWithConvPass.can_fuse(conv, bn, ep)
50-
if not can_fuse:
51-
why(node, "BatchNorm cannot be fused with Convolution")
43+
# First check if this can be fused with a convolution
44+
if input_node.op == "call_function":
45+
conv_name = format_target_name(input_node.target.__name__) # pyre-ignore
46+
if conv_name in ["convolution.default"]:
47+
can_fuse = FuseBatchNormWithConvPass.can_fuse(input_node, bn, ep)
48+
if can_fuse:
49+
return True
50+
51+
# If not fuseable with conv, check if it can be converted to depthwise conv
52+
can_convert = ConvertBatchNormToDepthwiseConvPass.can_convert_standalone_batch_norm(bn, ep)
53+
if not can_convert:
54+
why(node, "BatchNorm cannot be fused with Convolution or converted to depthwise conv")
5255
return False
5356

5457
return True

backends/xnnpack/test/passes/test_batch_norm_fusion.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,36 @@ def test_q8_batch_norm_fusion(self):
7171
.run_method_and_compare_outputs()
7272
)
7373

74+
def test_fp32_standalone_batch_norm_converts_to_depthwise_conv(self):
75+
"""
76+
Test that standalone batch norms (i.e. batch norms that are not fused with a conv)
77+
can be converted to depthwise convolutions and successfully partitioned and lowered.
78+
"""
79+
80+
class StandaloneBN(torch.nn.Module):
81+
def __init__(self):
82+
super().__init__()
83+
self.bn = torch.nn.BatchNorm2d(2)
84+
# Run forward to set up batch norm statistics
85+
self.forward(torch.randn(2, 2, 4, 4) * 2 + 2)
86+
87+
def forward(self, x):
88+
return self.bn(x)
89+
90+
(
91+
Tester(StandaloneBN().eval(), (torch.randn(2, 2, 4, 4),))
92+
.export()
93+
.to_edge()
94+
.check_count({self.bn_name: 1})
95+
.partition()
96+
.check_count({self.bn_name: 0}) # Should be partitioned and converted
97+
.run_method_and_compare_outputs()
98+
)
99+
74100
def test_fp32_batch_norm_no_fusion_doesnt_partition(self):
75101
"""
76-
We do not currently support standalone batch norms (i.e. batch norms that are
77-
not fused with a conv). This is planned, but until implemented, this test ensures
78-
that we do not partition the standalone batch norm and then fail to lower.
102+
DEPRECATED: We now support standalone batch norms by converting them to depthwise conv.
103+
This test remains for backwards compatibility but may be removed in the future.
79104
"""
80105

81106
class BN(torch.nn.Module):
@@ -86,6 +111,8 @@ def __init__(self):
86111
def forward(self, x):
87112
return self.bn(x)
88113

114+
# Note: This test is now testing the old behavior where standalone batch norms
115+
# without proper initialization may not be convertible
89116
(
90117
Tester(BN(), (torch.randn(2, 2, 4, 4),))
91118
.export()

0 commit comments

Comments
 (0)