Skip to content

Commit 07c6752

Browse files
authored
Forward fix Fix batch norm partitioning with Conv3d (#13696) (#14170)
Summary: Forward pyre fix https://www.internalfb.com/diff/D81069236 #13696 bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: digantdesai Differential Revision: D82118651
1 parent 9b6387f commit 07c6752

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

backends/xnnpack/_passes/fuse_batch_norm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,14 @@ def can_fuse(
115115
return False
116116

117117
# Check the rank of the convolution input - only Conv1d and 2d are supported.
118-
if is_conv and len(input_node.args[0].meta["val"].shape) not in (3, 4):
119-
return False
118+
if is_conv:
119+
conv_input = input_node.args[0]
120+
if (
121+
not isinstance(conv_input, torch.fx.Node)
122+
or "val" not in conv_input.meta
123+
or len(conv_input.meta["val"].shape) not in (3, 4)
124+
):
125+
return False
120126

121127
return True
122128

backends/xnnpack/test/passes/test_batch_norm_fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
self,
6161
in_features: int,
6262
out_features: int,
63-
kernel_size: Tuple[int, int],
63+
kernel_size: Tuple[int, int, int],
6464
):
6565
super().__init__()
6666
op = torch.nn.Conv3d

0 commit comments

Comments
 (0)