Skip to content

Commit 0d78c23

Browse files
authored
Fix DecomposeLayerNormPass to handle 6-arg layer_norm
Differential Revision: D90395786 Pull Request resolved: #16516
1 parent 806c8e8 commit 0d78c23

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def call(self, graph_module: torch.fx.GraphModule):
9090
args = node.args
9191
meta = node.meta
9292
match len(args):
93+
case 6:
94+
# torch.ops.aten.layer_norm.default has 6 args:
95+
# (input, normalized_shape, weight, bias, eps, cudnn_enable)
96+
# cudnn_enable is not used in the decomposition
97+
x, normalized_shape, weights, bias, epsilon, _cudnn_enable = args
9398
case 5:
9499
x, normalized_shape, weights, bias, epsilon = args
95100
case 4:

backends/arm/test/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def define_arm_tests():
1919
"ops/test_avg_pool2d.py",
2020
"ops/test_cat.py",
2121
"ops/test_conv2d.py",
22-
"ops/test_linear.py",
22+
"ops/test_linear.py",
2323
"ops/test_mul.py",
2424
"ops/test_permute.py",
2525
"ops/test_rsqrt.py",

0 commit comments

Comments
 (0)