diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py index 74c01442c9e..5ebb7e92dad 100644 --- a/backends/arm/_passes/decompose_layernorm_pass.py +++ b/backends/arm/_passes/decompose_layernorm_pass.py @@ -90,6 +90,11 @@ def call(self, graph_module: torch.fx.GraphModule): args = node.args meta = node.meta match len(args): + case 6: + # torch.ops.aten.layer_norm.default has 6 args: + # (input, normalized_shape, weight, bias, eps, cudnn_enable) + # cudnn_enable is not used in the decomposition + x, normalized_shape, weights, bias, epsilon, _cudnn_enable = args case 5: x, normalized_shape, weights, bias, epsilon = args case 4: diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index 8715ea80a14..14b4a37e346 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -19,7 +19,7 @@ def define_arm_tests(): "ops/test_avg_pool2d.py", "ops/test_cat.py", "ops/test_conv2d.py", - "ops/test_linear.py", + "ops/test_linear.py", "ops/test_mul.py", "ops/test_permute.py", "ops/test_rsqrt.py",