
Commit d05a793

Fix batch norm partitioning with Conv3d (pytorch#13696)
Summary: Models with a batch norm following a Conv3d caused an internal error during lowering. This diff fixes the issue by updating the partitioning logic to rely only on fusion with 1d and 2d convs. XNNPACK doesn't currently support standalone batch norms, so the partitioner only accepts batch norms that can be fused with a preceding conv. Conv3d can't be fused because XNNPACK has no Conv3d implementation, but the partitioner constraint was missing the logic to exclude it.

Differential Revision: D81069236
1 parent de0ff26 commit d05a793
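
For context, a minimal sketch (not part of the commit) of why checking the input rank against (3, 4) selects exactly Conv1d and Conv2d: conv inputs carry a batch dimension, a channel dimension, and one spatial dimension per variant, so a Conv3d input has rank 5.

import torch

# Rank of a conv input = batch + channels + spatial dims:
#   Conv1d: (N, C, L)       -> rank 3
#   Conv2d: (N, C, H, W)    -> rank 4
#   Conv3d: (N, C, D, H, W) -> rank 5
for conv, x in [
    (torch.nn.Conv1d(2, 2, 2), torch.randn(1, 2, 8)),
    (torch.nn.Conv2d(2, 2, 2), torch.randn(1, 2, 8, 8)),
    (torch.nn.Conv3d(2, 2, 2), torch.randn(1, 2, 8, 8, 8)),
]:
    # Mirrors the new partitioner constraint: fusible only when rank is 3 or 4.
    print(type(conv).__name__, "input rank:", x.dim(), "fusible:", x.dim() in (3, 4))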

File tree

2 files changed (+44, -1 lines)

backends/xnnpack/_passes/fuse_batch_norm.py

Lines changed: 9 additions & 1 deletion
@@ -82,12 +82,16 @@ def call(self, graph_module: torch.fx.GraphModule):
 
     @staticmethod
     def can_fuse(
-        input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
+        input_node: torch.fx.Node,
+        bn: torch.fx.Node,
+        program: ExportedProgram,
     ) -> bool:
         """
         Determine whether a BatchNorm node can be fused with the preceding convolution or linear node.
         """
 
+        is_conv = input_node.target == exir_ops.edge.aten.convolution.default
+
         # All users of the batch_norm node must be getitem ops.
         # batch_norm returns a 3-element tuple.
         # Each user must only access the first element of the tuple.
@@ -110,6 +114,10 @@ def can_fuse(
         ].count(False):
             return False
 
+        # Check the rank of the convolution input - only Conv1d and 2d are supported.
+        if is_conv and len(input_node.args[0].meta["val"].shape) not in (3, 4):
+            return False
+
         return True
 
     def _fuse_ops(
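
As background on why the partitioner only keeps batch norms it can fuse: for Conv1d/Conv2d the batch norm folds directly into the conv's weights, so XNNPACK never needs a standalone batch-norm kernel. Below is a sketch of the standard folding for Conv2d; this is the textbook formulation, not code from this diff, and it ignores details like grouped convs.

import torch

def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    # Standard inference-time folding (uses the BN running statistics):
    #   W' = W * gamma / sqrt(var + eps)
    #   b' = (b - mean) * gamma / sqrt(var + eps) + beta
    fused = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=True,
    )
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias
    return fused

With Conv3d there is no XNNPACK conv kernel to fold into, which is exactly why the rank check above bails out.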

backends/xnnpack/test/passes/test_batch_norm_fusion.py

Lines changed: 35 additions & 0 deletions
@@ -55,6 +55,26 @@ def forward(self, x):
             y = y + y
             return self.bn(y)
 
+    class ModelConv3dBN(torch.nn.Module):
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            kernel_size: Tuple[int, int, int],
+        ):
+            super().__init__()
+            op = torch.nn.Conv3d
+            self.conv3d = op(in_features, out_features, kernel_size)
+            self.bn = torch.nn.BatchNorm3d(out_features)
+            self.forward(torch.randn(2, 2, 4, 4, 4) * 2 + 2)  # update the BN stats
+
+        def forward(self, x):
+            y = self.conv3d(x)
+            y = self.bn(y)
+            y = self.conv3d(y)
+            y = y + y
+            return self.bn(y)
+
     def test_fp32_conv_batch_norm_fusion(self):
         for transpose in [False, True]:
             (
@@ -142,3 +162,18 @@ def forward(self, x):
             .to_edge_transform_and_lower()
             .check_count({self.bn_name: 1})
         )
+
+    def test_fp32_conv3d_batch_norm_doesnt_partition(self):
+        """
+        Conv3d is not currently supported by XNNPACK. We also don't support standalone
+        batch norms yet (i.e. batch norms that are not fused with a conv). As such, we don't
+        want to partition the standalone batch norm and then fail to lower.
+        """
+        (
+            Tester(self.ModelConv3dBN(2, 2, (2, 2, 2)), (torch.randn(2, 2, 4, 4, 4),))
+            .export()
+            .dump_artifact()
+            .to_edge_transform_and_lower()
+            .check_count({self.bn_name: 2})
+            .run_method_and_compare_outputs()
+        )
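
The final check_count expects both batch_norm nodes to survive to_edge_transform_and_lower un-partitioned. For illustration, a standalone sketch of the same kind of node counting using plain torch.export (the Tester helper's internals are not reproduced here, and the substring match on the target name is an assumption about how the lowered op is spelled):

import torch

class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv3d(2, 2, 2)
        self.bn = torch.nn.BatchNorm3d(2)

    def forward(self, x):
        return self.bn(self.conv(x))

ep = torch.export.export(ConvBN().eval(), (torch.randn(2, 2, 4, 4, 4),))
n_bn = sum(
    1
    for node in ep.graph.nodes
    if node.op == "call_function" and "batch_norm" in str(node.target)
)
print("batch_norm nodes in the graph:", n_bn)  # 1 for this toy model; the test above expects 2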
