
Commit 1aa7400

GregoryComer authored and facebook-github-bot committed
Support un-fused batchnorm1d/2d on XNNPACK via decomposition (#16533)
Summary: Add a new pass, DecomposeBatchNorm, which converts standalone (non-fused) batch norm operators into equivalent 1x1 depthwise convolutions. This prevents delegation graph breaks when batch norm operators can't be fused. Differential Revision: D90422630
1 parent 3dd80c1 commit 1aa7400
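
The decomposition relies on frozen batch norm being a per-channel affine transform, which a 1x1 depthwise convolution expresses exactly. A minimal sketch of that equivalence in eager PyTorch (illustrative shapes and values, not code from this commit):

# Minimal sketch (not from this commit): a frozen BatchNorm2d matches a 1x1
# depthwise Conv2d whose weights fold in the BN statistics.
import torch

C = 4
bn = torch.nn.BatchNorm2d(C).eval()
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1.0, 1.0)
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)

# Fold stats into per-channel weight/bias: y = gamma*(x - mean)/sqrt(var + eps) + beta.
conv = torch.nn.Conv2d(C, C, kernel_size=1, groups=C).eval()
denom = torch.sqrt(bn.running_var + bn.eps)
with torch.no_grad():
    conv.weight.copy_((bn.weight / denom).reshape(C, 1, 1, 1))
    conv.bias.copy_(bn.bias - bn.running_mean * bn.weight / denom)

x = torch.randn(2, C, 8, 8)
torch.testing.assert_close(conv(x), bn(x), rtol=1e-4, atol=1e-5)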

File tree

8 files changed: +663 -63 lines changed

backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
     ConvertToUpsampleBilinear2d,
 )
+from executorch.backends.xnnpack._passes.decompose_batch_norm import DecomposeBatchNorm
 from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
@@ -76,6 +77,7 @@ def __init__(
             ConvertToSDPAPass,
             ConstPropPass,
             FuseBatchNormPass,
+            DecomposeBatchNorm,
             FuseActivationPass,
             DecomposeConcatenate,
             RemoveGetItemPass,
backends/xnnpack/_passes/decompose_batch_norm.py

Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator

import torch
from executorch.backends.transforms.utils import create_constant_placeholder
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_param_tensor,
    get_tensor_name,
    is_param_node,
)
from executorch.exir.backend.utils import WhyNoPartition
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.graph_signature import InputKind
from torch.fx.passes.infra.pass_base import PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class DecomposeBatchNorm(XNNPACKPass):
    """
    Decompose batch norm operators into 1x1 depthwise convolution.
    """

    BATCH_NORM_OPS = {
        exir_ops.edge.aten.native_batch_norm.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
    }

    @staticmethod
    def can_decompose_batch_norm(
        node: torch.fx.Node,
        exported_program: torch.export.ExportedProgram,
        why: WhyNoPartition | None = None,
    ) -> bool:
        """
        Determine whether the given batch norm node can be decomposed by this pass.
        """

        if node.op != "call_function" or node.target not in DecomposeBatchNorm.BATCH_NORM_OPS:
            return False

        input_meta = node.args[0].meta["val"]

        # Since we're converting to conv and XNNPACK doesn't support conv3d, we can't
        # handle BatchNorm3d. Validate the input dimension. We'll take NC, NCL, or NCHW.
        if input_meta.dim() not in (2, 3, 4):
            if why:
                why(node, f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.")
            return False

        # The batch norm node returns a tuple of the output and other values we don't
        # care about. All users must be getitem nodes that fetch the output (index 0).
        # The partitioner should enforce this, but we'll check it here too.
        for user in node.users:
            if user.target != operator.getitem or user.args[1] != 0:
                if why:
                    why(node, "Batch norm users must only access the output tensor.")
                return False

        # Channel dimension and non-input args must be statically known.
        if not isinstance(input_meta.shape[1], int):
            if why:
                why(node, f"Channel dimension must be statically known, but was {input_meta.shape[1]}.")
            return False

        if not is_param_node(exported_program, node.args[1]) or not is_param_node(exported_program, node.args[2]):
            if why:
                why(node, "Batch norm affine weight and bias must be static.")
            return False

        if not is_param_node(exported_program, node.args[3]) or not is_param_node(exported_program, node.args[4]):
            if why:
                why(node, "Batch norm running mean and variance must be static.")
            return False

        if isinstance(node.args[-1], torch.fx.Node):
            if why:
                why(node, "Batch norm epsilon must be static.")
            return False

        return True

    @staticmethod
    def compute_w_and_b(
        eps: float,
        running_mean: torch.Tensor,  # [C]
        running_var: torch.Tensor,  # [C]
        gamma: torch.Tensor,  # [C], learned weight
        beta: torch.Tensor,  # [C], learned bias
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute equivalent per-channel weight and bias to match the batch norm
        computation with frozen values.
        """

        # See https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
        denom = torch.sqrt(running_var + torch.Tensor([eps]))
        weight = gamma / denom
        bias = -running_mean * gamma / denom + beta

        return weight, bias

    def replace_bn_node_with_conv(
        self,
        bn_node: torch.fx.Node,
        graph_module: torch.fx.GraphModule,
    ) -> torch.fx.Node:
        """
        Replace a BatchNorm with NCL or NCHW input with an equivalent depthwise
        convolution.
        """

        # Compute the equivalent per-channel weights and biases.
        # Note that the batch norm node args are
        # (input, gamma, beta, running_mean, running_var, [training], momentum, eps).
        # The training arg is not present in the _no_training variant.
        weight, bias = DecomposeBatchNorm.compute_w_and_b(
            eps=bn_node.args[-1],
            running_mean=get_param_tensor(self.exported_program, bn_node.args[3]),
            running_var=get_param_tensor(self.exported_program, bn_node.args[4]),
            gamma=get_param_tensor(self.exported_program, bn_node.args[1]),
            beta=get_param_tensor(self.exported_program, bn_node.args[2]),
        )

        # Conv weights have shape [out_c, in_c/g, spatial...].
        # For depthwise conv, in_c = g. The kernel is also 1x1 (or just 1, for 1d).
        #
        # BatchNorm weights have shape [in_c].
        # So we just need to unsqueeze the [in_c] to [in_c, 1, 1, [1]].
        input_meta = bn_node.args[0].meta["val"]
        channel_count = input_meta.shape[1]
        spatial_dims = max(input_meta.dim() - 2, 1)  # Min of 1 since 1d can be NC or NCL.
        new_weight_shape = [weight.shape[0], 1] + [1] * spatial_dims
        weight = weight.reshape(new_weight_shape)

        # Generate names for the new weight and bias parameters based on the original
        # batch norm gamma parameter name.
        gamma_name = get_tensor_name(self.exported_program, bn_node.args[1])
        weight_name = (gamma_name + "_decomposed_bn_weight").replace(".", "_")
        bias_name = (gamma_name + "_decomposed_bn_bias").replace(".", "_")

        # Insert the new weight and bias as constant placeholders in the graph.
        with graph_module.graph.inserting_before(bn_node.args[1]):
            weight_node = create_constant_placeholder(
                exp_program=self.exported_program,
                graph=graph_module.graph,
                kind=InputKind.PARAMETER,
                name=weight_name,
                data=weight,
            )
            bias_node = create_constant_placeholder(
                exp_program=self.exported_program,
                graph=graph_module.graph,
                kind=InputKind.PARAMETER,
                name=bias_name,
                data=bias,
            )

        with graph_module.graph.inserting_after(bn_node):
            conv_node = graph_module.graph.call_function(
                exir_ops.edge.aten.convolution.default,
                args=(
                    bn_node.args[0],  # Input
                    weight_node,  # Weight
                    bias_node,  # Bias
                    [1] * spatial_dims,  # Stride
                    [0] * spatial_dims,  # Padding
                    [1] * spatial_dims,  # Dilation
                    False,  # Transposed
                    [0] * spatial_dims,  # Output padding
                    channel_count,  # Groups (depthwise, so groups = in_channels)
                ),
            )

        # Find the getitem user nodes and replace them with the conv node.
        # The decomposition checks above enforce that the node is only used by getitem[0].
        users = list(bn_node.users)
        for user in users:
            user.replace_all_uses_with(conv_node)
            graph_module.graph.erase_node(user)

        graph_module.graph.erase_node(bn_node)
        return conv_node

    def decompose_node(self, node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> None:
        input_meta = node.args[0].meta["val"]

        # These should be checked by the partitioner and calling node,
        # so we should never fail these checks.
        check_or_raise(
            node.op == "call_function" and node.target in DecomposeBatchNorm.BATCH_NORM_OPS,
            f"Invalid batch norm operator {node.op}.",
        )

        check_or_raise(
            input_meta.dim() in (2, 3, 4),
            f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.",
        )

        channel_count = input_meta.shape[1]
        check_or_raise(
            isinstance(channel_count, int),
            f"Channel dimension must be statically known, but was {channel_count}.",
        )

        # Create the convolution node.
        conv_node = self.replace_bn_node_with_conv(node, graph_module)

        # BatchNorm1d can be NC or NCL. Conv1d requires the L dim, so unsqueeze NC -> NCL.
        if input_meta.dim() == 2:
            with graph_module.graph.inserting_before(conv_node):
                # Insert unsqueeze node before.
                unsqueeze_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.unsqueeze_copy.default,
                    args=(conv_node.args[0], 2),
                )
                conv_node.args = (unsqueeze_node, *conv_node.args[1:])

            with graph_module.graph.inserting_after(conv_node):
                # Insert squeeze node after.
                squeeze_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.squeeze_copy.dim,
                    args=(conv_node, 2),
                )
                conv_node.replace_all_uses_with(squeeze_node)
                # This gets overwritten by replace_all_uses_with. Maybe there's
                # a better solution?
                squeeze_node.args = (conv_node, *squeeze_node.args[1:])

    # override
    def call(self, graph_module: torch.fx.GraphModule):
        # Find and transform all eligible batch norm nodes.
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target in self.BATCH_NORM_OPS:
                if self.can_decompose_batch_norm(node, self.exported_program):
                    self.decompose_node(node, graph_module)

        graph_module.recompile()

        # Propagate metadata and retrace module
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
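
The weight/bias folding in compute_w_and_b, plus the unsqueeze/squeeze wrapping in decompose_node, can be checked numerically in eager PyTorch. A minimal sketch of the rank-2 (NC) case, with illustrative names and shapes (not part of the commit):

# Numerically check the rank-2 (NC) path, which the pass handles by unsqueezing
# to NCL, running a length-1 depthwise conv1d, then squeezing back.
import torch

C, eps = 6, 1e-5
x = torch.randn(3, C)  # NC input, as accepted by BatchNorm1d
gamma, beta = torch.randn(C), torch.randn(C)
mean, var = torch.randn(C), torch.rand(C) + 0.1

# Reference: inference-mode batch norm.
ref = torch.nn.functional.batch_norm(x, mean, var, gamma, beta, training=False, eps=eps)

# Same algebra as compute_w_and_b above.
denom = torch.sqrt(var + eps)
w = (gamma / denom).reshape(C, 1, 1)  # [out_c, in_c/groups, kernel=1]
b = -mean * gamma / denom + beta

# Mirror decompose_node: unsqueeze NC -> NCL, conv with groups == channels, squeeze back.
out = torch.nn.functional.conv1d(x.unsqueeze(2), w, b, groups=C).squeeze(2)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)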

backends/xnnpack/_passes/fuse_batch_norm.py

Lines changed: 21 additions & 3 deletions
@@ -11,19 +11,17 @@
     create_constant_placeholder,
     delete_constant_placeholder,
 )
-
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
-
 from executorch.backends.xnnpack.utils.utils import (
     get_param_tensor,
     get_tensor_name,
     is_param_node,
 )
 from executorch.exir import ExportedProgram
+from executorch.exir.backend.utils import WhyNoPartition
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
 from torch.export.graph_signature import InputKind
-
 from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights
@@ -85,11 +83,23 @@ def can_fuse(
         input_node: torch.fx.Node,
         bn: torch.fx.Node,
         program: ExportedProgram,
+        why: WhyNoPartition | None = None,
     ) -> bool:
         """
         Determine whether a BatchNorm node can be fused with the preceding convolution or linear node.
         """

+        if input_node.op != "call_function":
+            return False
+
+        if input_node.target not in (
+            exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.linear.default
+        ):
+            if why:
+                why(bn, "Input node must be a convolution or linear op.")
+            return False
+
         is_conv = input_node.target == exir_ops.edge.aten.convolution.default

         # All users of the batch_norm node must be getitem ops.
@@ -98,6 +108,8 @@ def can_fuse(
         if [
             (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
         ].count(False):
+            if why:
+                why(bn, "Batch norm users must only access the output tensor.")
             return False

         input_node_weights = input_node.args[1]
@@ -107,11 +119,15 @@ def can_fuse(
         if not isinstance(input_node_weights, torch.fx.Node) or not isinstance(
             bn_weights, torch.fx.Node
         ):
+            if why:
+                why(bn, "Input node weights must be parameters.")
             return False

         if [
             is_param_node(program, node) for node in {input_node_weights, bn_weights}
         ].count(False):
+            if why:
+                why(bn, "Node weights must be static.")
             return False

         # Check the rank of the convolution input - only Conv1d and 2d are supported.
@@ -122,6 +138,8 @@ def can_fuse(
             or "val" not in conv_input.meta
             or len(conv_input.meta["val"].shape) not in (3, 4)
         ):
+            if why:
+                why(bn, "Convolution input must be rank 3 or 4.")
             return False

         return True
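
For context on what can_fuse is gating: when the checks pass, the batch norm statistics are folded into the preceding convolution's weights via torch.nn.utils.fusion.fuse_conv_bn_weights, imported above. A minimal sketch of that fusion, under assumed shapes (not code from this commit):

# Fold frozen BN stats into a preceding conv's weight/bias, as the fusion
# helper does. Shapes and values here are illustrative.
import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)

fused = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
fused.weight, fused.bias = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)

x = torch.randn(1, 3, 16, 16)
torch.testing.assert_close(fused(x), bn(conv(x)), rtol=1e-4, atol=1e-5)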

backends/xnnpack/operators/op_squeeze.py

Lines changed: 4 additions & 2 deletions
@@ -36,8 +36,9 @@ def define_node(
         debug_handle: int,
     ) -> None:

+        dim = cast(int, node.args[1])
         check_or_raise(
-            cast(int, node.args[1]) == -1,
+            dim == -1 or dim == len(node.args[0].meta["val"].shape) - 1,
             "XNNPACK currently only supports squeezing in last dimension",
         )
@@ -98,8 +99,9 @@ def define_node(
         debug_handle: int,
     ) -> None:

+        dim = cast(int, node.args[1])
         check_or_raise(
-            cast(int, node.args[1]) == -1,
+            dim == -1 or dim == len(node.args[0].meta["val"].shape),
             "XNNPACK currently only supports unsqueezing in last dimension",
         )
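
The relaxed checks now accept either the -1 alias or the explicit trailing index; note that for unsqueeze the trailing index equals the input rank, since the output gains a dimension. A quick illustration with assumed shapes:

# For squeeze, the trailing dim of a rank-4 input is -1 or 3 (rank - 1); for
# unsqueeze, appending to a rank-3 input is -1 or 3 (== rank).
import torch

x = torch.randn(2, 3, 4, 1)
assert torch.equal(torch.squeeze(x, -1), torch.squeeze(x, x.dim() - 1))

y = torch.randn(2, 3, 4)
assert torch.equal(torch.unsqueeze(y, -1), torch.unsqueeze(y, y.dim()))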

backends/xnnpack/partition/config/node_configs.py

Lines changed: 5 additions & 11 deletions
@@ -9,6 +9,7 @@
 from typing import List, Optional

 import torch
+from executorch.backends.xnnpack._passes.decompose_batch_norm import DecomposeBatchNorm
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
@@ -35,18 +36,11 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         bn = node
         input_node = node.all_input_nodes[0]

-        if input_node.op != "call_function":
-            return False
-
-        input_name = format_target_name(input_node.target.__name__)  # pyre-ignore
-
-        if input_name not in ["convolution.default", "linear.default"]:
-            why(node, f"Invalid input target {input_name.split('.')[0]}")
-            return False
-
+        can_decompose = DecomposeBatchNorm.can_decompose_batch_norm(node, ep, why)
         can_fuse = FuseBatchNormPass.can_fuse(input_node, bn, ep)
-        if not can_fuse:
-            why(node, f"BatchNorm cannot be fused with {input_name.split('.')[0]}")
+
+        if not can_fuse and not can_decompose:
+            why(node, f"BatchNorm cannot be decomposed or fused with {input_node}")
             return False

         return True
