Commit ddd4c1e

GregoryComer authored and facebook-github-bot committed
Support un-fused batchnorm1d/2d on XNNPACK via decomposition
Summary: Add a new pass - DecomposeBatchNorm - which converts standalone (non-fused) batch norm operators into 1x1 depthwise convolutions. This prevents delegation graph breaks when batch norm operators can't be fused.

Differential Revision: D90422630
1 parent 236f847 commit ddd4c1e
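For intuition on the decomposition: with frozen statistics, batch norm applies a per-channel affine transform y = (x - mean) / sqrt(var + eps) * gamma + beta, which a 1x1 depthwise convolution reproduces with weight w = gamma / sqrt(var + eps) and bias b = beta - mean * w. A minimal standalone sketch of that equivalence (illustrative only, not part of this diff):

# Standalone sketch: frozen batch norm vs. an equivalent 1x1 depthwise conv.
import torch
import torch.nn.functional as F

C, eps = 8, 1e-5
x = torch.randn(2, C, 16, 16)
gamma, beta = torch.randn(C), torch.randn(C)
mean, var = torch.randn(C), torch.rand(C) + 0.1

bn_out = F.batch_norm(x, mean, var, gamma, beta, training=False, eps=eps)

# Fold batch norm into a per-channel scale and shift.
w = gamma / torch.sqrt(var + eps)
b = beta - mean * w
conv_out = F.conv2d(x, w.reshape(C, 1, 1, 1), b, groups=C)  # 1x1 depthwise

torch.testing.assert_close(bn_out, conv_out)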

File tree

7 files changed: +641 -16 lines changed


backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
     ConvertToUpsampleBilinear2d,
 )
+from executorch.backends.xnnpack._passes.decompose_batch_norm import DecomposeBatchNorm
 from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass

@@ -76,6 +77,7 @@ def __init__(
         ConvertToSDPAPass,
         ConstPropPass,
         FuseBatchNormPass,
+        DecomposeBatchNorm,
         FuseActivationPass,
         DecomposeConcatenate,
         RemoveGetItemPass,
backends/xnnpack/_passes/decompose_batch_norm.py

Lines changed: 226 additions & 0 deletions

@@ -0,0 +1,226 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import operator
+
+import torch
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_param_tensor,
+    is_param_node,
+)
+from executorch.exir.backend.utils import WhyNoPartition
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.fx.passes.infra.pass_base import PassResult
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+class DecomposeBatchNorm(XNNPACKPass):
+    """
+    Decompose standalone batch norm operators into 1x1 depthwise convolutions.
+    """
+
+    BATCH_NORM_OPS = {
+        exir_ops.edge.aten.native_batch_norm.default,
+        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+    }
+
+    @staticmethod
+    def can_decompose_batch_norm(
+        node: torch.fx.Node,
+        exported_program: torch.export.ExportedProgram,
+        why: WhyNoPartition | None = None,
+    ) -> bool:
+        """
+        Determine whether the given batch norm node can be decomposed by this pass.
+        """
+
+        if (
+            node.op != "call_function"
+            or node.target not in DecomposeBatchNorm.BATCH_NORM_OPS
+        ):
+            return False
+
+        input_meta = node.args[0].meta["val"]
+
+        # Since we're converting to conv and XNNPACK doesn't support conv3d, we
+        # can't handle BatchNorm3d. Validate the input rank. We'll take NC, NCL,
+        # or NCHW.
+        if input_meta.dim() not in (2, 3, 4):
+            if why:
+                why(
+                    node,
+                    f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.",
+                )
+            return False
+
+        # The batch norm node returns a tuple of the output and other values we
+        # don't care about. All users must be getitem nodes that fetch the output
+        # (index 0). The partitioner should enforce this, but we check it here too.
+        for user in node.users:
+            if user.target != operator.getitem or user.args[1] != 0:
+                if why:
+                    why(node, "Batch norm users must only access the output tensor.")
+                return False
+
+        # The channel dimension and non-input args must be statically known.
+        if not isinstance(input_meta.shape[1], int):
+            if why:
+                why(
+                    node,
+                    f"Channel dimension must be statically known, but was {input_meta.shape[1]}.",
+                )
+            return False
+
+        if not is_param_node(exported_program, node.args[1]) or not is_param_node(
+            exported_program, node.args[2]
+        ):
+            if why:
+                why(node, "Batch norm affine weight and bias must be static.")
+            return False
+
+        if not is_param_node(exported_program, node.args[3]) or not is_param_node(
+            exported_program, node.args[4]
+        ):
+            if why:
+                why(node, "Batch norm running mean and variance must be static.")
+            return False
+
+        if isinstance(node.args[-1], torch.fx.Node):
+            if why:
+                why(node, "Batch norm epsilon must be static.")
+            return False
+
+        return True
+
+    @staticmethod
+    def compute_w_and_b(
+        eps: float,
+        running_mean: torch.Tensor,  # [C]
+        running_var: torch.Tensor,  # [C]
+        gamma: torch.Tensor,  # [C], learned weight
+        beta: torch.Tensor,  # [C], learned bias
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute equivalent per-channel weight and bias values to match the batch
+        norm computation with frozen statistics.
+        """
+
+        # See https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
+        denom = torch.sqrt(running_var + eps)
+        weight = gamma / denom
+        bias = -running_mean * gamma / denom + beta
+
+        return weight, bias
+
+    @staticmethod
+    def replace_bn_node_with_conv(
+        bn_node: torch.fx.Node,
+        graph_module: torch.fx.GraphModule,
+        exported_program: torch.export.ExportedProgram,
+    ) -> torch.fx.Node:
+        """
+        Replace a batch norm node with NCL or NCHW input with an equivalent
+        depthwise convolution.
+        """
+
+        # Compute the equivalent per-channel weights and biases.
+        # Note that the batch norm node args are
+        # (input, gamma, beta, running_mean, running_var, [training], momentum, eps).
+        # The training arg is not present in the _no_training variant.
+        weight, bias = DecomposeBatchNorm.compute_w_and_b(
+            eps=bn_node.args[-1],
+            running_mean=get_param_tensor(exported_program, bn_node.args[3]),
+            running_var=get_param_tensor(exported_program, bn_node.args[4]),
+            gamma=get_param_tensor(exported_program, bn_node.args[1]),
+            beta=get_param_tensor(exported_program, bn_node.args[2]),
+        )
+
+        with graph_module.graph.inserting_after(bn_node):
+            # Conv weights have shape [out_c, in_c/g, spatial...].
+            # For depthwise conv, in_c = g. The kernel is also 1x1 (or just 1, for 1d).
+            #
+            # BatchNorm weights have shape [in_c], so we just need to unsqueeze
+            # the [in_c] to [in_c, 1, 1, [1]].
+            input_meta = bn_node.args[0].meta["val"]
+            channel_count = input_meta.shape[1]
+            spatial_dims = max(input_meta.dim() - 2, 1)  # Min of 1, since 1d can be NC or NCL.
+            new_weight_shape = [weight.shape[0], 1] + [1] * spatial_dims
+            weight = weight.reshape(new_weight_shape)
+
+            # Insert the new weight and biases in the graph.
+            # TODO?
+
+            conv_node = graph_module.graph.call_function(
+                exir_ops.edge.aten.convolution.default,
+                args=(
+                    bn_node.args[0],  # Input
+                    weight,  # Weight
+                    bias,  # Bias
+                    [1] * spatial_dims,  # Stride
+                    [0] * spatial_dims,  # Padding
+                    [1] * spatial_dims,  # Dilation
+                    False,  # Transposed
+                    [0] * spatial_dims,  # Output padding
+                    channel_count,  # Groups (depthwise, so groups = in_channels)
+                ),
+            )
+
+        # Find the getitem user nodes and replace them with the conv node.
+        # The decomposition checks above enforce that the node is only used by
+        # getitem[0].
+        users = list(bn_node.users)
+        for user in users:
+            user.replace_all_uses_with(conv_node)
+            graph_module.graph.erase_node(user)
+
+        graph_module.graph.erase_node(bn_node)
+        return conv_node
+
+    def decompose_node(
+        self, node: torch.fx.Node, graph_module: torch.fx.GraphModule
+    ) -> None:
+        input_meta = node.args[0].meta["val"]
+
+        # These conditions are checked by the partitioner and the calling code,
+        # so we should never fail these checks.
+        check_or_raise(
+            node.op == "call_function"
+            and node.target in DecomposeBatchNorm.BATCH_NORM_OPS,
+            f"Invalid batch norm operator {node.op}.",
+        )
+
+        check_or_raise(
+            input_meta.dim() in (2, 3, 4),
+            f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.",
+        )
+
+        channel_count = input_meta.shape[1]
+        check_or_raise(
+            isinstance(channel_count, int),
+            f"Channel dimension must be statically known, but was {channel_count}.",
+        )
+
+        # Create the convolution node.
+        conv_node = self.replace_bn_node_with_conv(
+            node, graph_module, self.exported_program
+        )
+
+        # BatchNorm1d can be NC or NCL. Conv1d requires the L dim, so unsqueeze
+        # NC -> NCL.
+        if input_meta.dim() == 2:
+            with graph_module.graph.inserting_before(conv_node):
+                # Insert an unsqueeze node before the conv.
+                unsqueeze_node = graph_module.graph.call_function(
+                    exir_ops.edge.aten.unsqueeze_copy.default,
+                    args=(conv_node.args[0], 2),
+                )
+                conv_node.args = (unsqueeze_node, *conv_node.args[1:])
+
+            with graph_module.graph.inserting_after(conv_node):
+                # Insert a squeeze node after the conv.
+                squeeze_node = graph_module.graph.call_function(
+                    exir_ops.edge.aten.squeeze_copy.dim,
+                    args=(conv_node, 2),
+                )
+                conv_node.replace_all_uses_with(squeeze_node)
+                # The squeeze input gets overwritten by replace_all_uses_with,
+                # so restore it. Maybe there's a better solution?
+                squeeze_node.args = (conv_node, *squeeze_node.args[1:])
+
+    # override
+    def call(self, graph_module: torch.fx.GraphModule):
+        # Decompose any remaining (un-fused) batch norm nodes into depthwise convs.
+        for node in list(graph_module.graph.nodes):
+            if node.op == "call_function" and node.target in self.BATCH_NORM_OPS:
+                if self.can_decompose_batch_norm(node, self.exported_program):
+                    self.decompose_node(node, graph_module)
+
+        graph_module.recompile()
+
+        # Propagate metadata and retrace the module.
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
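The rank-2 (NC) path above wraps the conv in an unsqueeze/squeeze pair, since Conv1d requires an NCL input. A standalone sketch of that round trip (illustrative only, not part of this diff):

# Standalone sketch: the NC -> NCL -> NC round trip used for rank-2 BatchNorm1d inputs.
import torch
import torch.nn.functional as F

C, eps = 4, 1e-5
x = torch.randn(3, C)  # Rank-2 (NC) input, as BatchNorm1d allows.
gamma, beta = torch.randn(C), torch.randn(C)
mean, var = torch.randn(C), torch.rand(C) + 0.1

bn_out = F.batch_norm(x, mean, var, gamma, beta, training=False, eps=eps)

w = gamma / torch.sqrt(var + eps)
b = beta - mean * w

# Conv1d needs NCL: unsqueeze a length dim, convolve depthwise, squeeze it back.
y = F.conv1d(x.unsqueeze(2), w.reshape(C, 1, 1), b, groups=C).squeeze(2)

torch.testing.assert_close(bn_out, y)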

backends/xnnpack/_passes/fuse_batch_norm.py

Lines changed: 21 additions & 3 deletions
@@ -11,19 +11,17 @@
     create_constant_placeholder,
     delete_constant_placeholder,
 )
-
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
-
 from executorch.backends.xnnpack.utils.utils import (
     get_param_tensor,
     get_tensor_name,
     is_param_node,
 )
 from executorch.exir import ExportedProgram
+from executorch.exir.backend.utils import WhyNoPartition
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
 from torch.export.graph_signature import InputKind
-
 from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights

@@ -85,11 +83,23 @@ def can_fuse(
         input_node: torch.fx.Node,
         bn: torch.fx.Node,
         program: ExportedProgram,
+        why: WhyNoPartition | None = None,
     ) -> bool:
         """
         Determine whether a BatchNorm node can be fused with the preceding convolution or linear node.
         """

+        if input_node.op != "call_function":
+            return False
+
+        if input_node.target not in (
+            exir_ops.edge.aten.convolution.default,
+            exir_ops.edge.aten.linear.default,
+        ):
+            if why:
+                why(bn, "Input node must be a convolution or linear op.")
+            return False
+
         is_conv = input_node.target == exir_ops.edge.aten.convolution.default

         # All users of the batch_norm node must be getitem ops.
@@ -98,6 +108,8 @@ def can_fuse(
         if [
             (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
         ].count(False):
+            if why:
+                why(bn, "Batch norm users must only access the output tensor.")
             return False

         input_node_weights = input_node.args[1]
@@ -107,11 +119,15 @@ def can_fuse(
         if not isinstance(input_node_weights, torch.fx.Node) or not isinstance(
             bn_weights, torch.fx.Node
         ):
+            if why:
+                why(bn, "Input node weights must be parameters.")
             return False

         if [
             is_param_node(program, node) for node in {input_node_weights, bn_weights}
         ].count(False):
+            if why:
+                why(bn, "Node weights must be static.")
             return False

         # Check the rank of the convolution input - only Conv1d and 2d are supported.
@@ -122,6 +138,8 @@ def can_fuse(
             or "val" not in conv_input.meta
             or len(conv_input.meta["val"].shape) not in (3, 4)
         ):
+            if why:
+                why(bn, "Convolution input must be rank 3 or 4.")
             return False

         return True
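For reference, the fusion path this check guards folds the batch norm statistics into the preceding op's weights via torch.nn.utils.fusion, which the pass imports. A standalone sketch of the conv case (illustrative only, not part of this diff):

# Standalone sketch: conv + batch norm fusion via fuse_conv_bn_weights,
# the torch.nn.utils.fusion helper imported above.
import torch
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights

C_in, C_out, eps = 3, 8, 1e-5
x = torch.randn(1, C_in, 16, 16)
conv_w, conv_b = torch.randn(C_out, C_in, 3, 3), torch.randn(C_out)
gamma, beta = torch.randn(C_out), torch.randn(C_out)
mean, var = torch.randn(C_out), torch.rand(C_out) + 0.1

ref = F.batch_norm(
    F.conv2d(x, conv_w, conv_b, padding=1),
    mean, var, gamma, beta, training=False, eps=eps,
)

fused_w, fused_b = fuse_conv_bn_weights(conv_w, conv_b, mean, var, eps, gamma, beta)
out = F.conv2d(x, fused_w, fused_b, padding=1)

torch.testing.assert_close(ref, out, rtol=1e-4, atol=1e-4)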

backends/xnnpack/operators/op_squeeze.py

Lines changed: 4 additions & 2 deletions
@@ -36,8 +36,9 @@ def define_node(
         debug_handle: int,
     ) -> None:

+        dim = cast(int, node.args[1])
         check_or_raise(
-            cast(int, node.args[1]) == -1,
+            dim == -1 or dim == len(node.args[0].meta["val"].shape) - 1,
             "XNNPACK currently only supports squeezing in last dimension",
         )

@@ -98,8 +99,9 @@ def define_node(
         debug_handle: int,
     ) -> None:

+        dim = cast(int, node.args[1])
         check_or_raise(
-            cast(int, node.args[1]) == -1,
+            dim == -1 or dim == len(node.args[0].meta["val"].shape),
             "XNNPACK currently only supports unsqueezing in last dimension",
         )

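The relaxed checks accept an explicit positive index for the last dimension in addition to -1. A quick standalone illustration of why the two forms are interchangeable (not part of this diff):

# Standalone sketch: a positive last-dim index is equivalent to -1 for squeeze,
# and dim == ndim is the last valid position for unsqueeze.
import torch

x = torch.randn(2, 3, 1)
assert torch.equal(torch.squeeze(x, -1), torch.squeeze(x, x.dim() - 1))

y = torch.randn(2, 3)
assert torch.equal(torch.unsqueeze(y, -1), torch.unsqueeze(y, y.dim()))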
backends/xnnpack/partition/config/node_configs.py

Lines changed: 5 additions & 11 deletions
@@ -9,6 +9,7 @@
 from typing import List, Optional

 import torch
+from executorch.backends.xnnpack._passes.decompose_batch_norm import DecomposeBatchNorm
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,

@@ -35,18 +36,11 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         bn = node
         input_node = node.all_input_nodes[0]

-        if input_node.op != "call_function":
-            return False
-
-        input_name = format_target_name(input_node.target.__name__)  # pyre-ignore
-
-        if input_name not in ["convolution.default", "linear.default"]:
-            why(node, f"Invalid input target {input_name.split('.')[0]}")
-            return False
-
+        can_decompose = DecomposeBatchNorm.can_decompose_batch_norm(node, ep, why)
         can_fuse = FuseBatchNormPass.can_fuse(input_node, bn, ep)
-        if not can_fuse:
-            why(node, f"BatchNorm cannot be fused with {input_name.split('.')[0]}")
+
+        if not can_fuse and not can_decompose:
+            why(node, f"BatchNorm cannot be decomposed or fused with {input_node}")
             return False

         return True
