Skip to content

Commit 862df9c

Browse files
Andrew Pullinfacebook-github-bot
authored andcommitted
Fix DecomposeLayerNormPass to handle 6-arg layer_norm (#16516)
Summary: ## Problem When using `nn.LayerNorm` in models that go through the ARM backend's quantization flow, the `DecomposeLayerNormPass` fails with: ``` ValueError: DecomposeLayerNormPass: too many values to unpack (expected 2) ``` This happens because `torch.ops.aten.layer_norm.default` has **6 arguments**: ``` layer_norm(input, normalized_shape, weight, bias, eps, cudnn_enable) ``` But `DecomposeLayerNormPass` only handled up to 5 arguments (for `native_layer_norm`). The error occurs during `transform_for_annotation_pipeline` in the ARM quantizer, which runs before edge transformation when the op is still `aten.layer_norm.default`. ## Solution Add `case 6:` to the `match len(args)` block in `DecomposeLayerNormPass.call()` to handle the 6th argument (`cudnn_enable`). This argument is simply ignored during decomposition since it's only relevant for cuDNN GPU optimization. --- > Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Confucius Session](https://www.internalfb.com/confucius?host=92481.od.fbinfra.net&port=8086&tab=Chat&session_id=eace3d92-ed78-11f0-b67c-c7843469b0d5&entry_name=Code+Assist), [Trace](https://www.internalfb.com/confucius?session_id=eace3d92-ed78-11f0-b67c-c7843469b0d5&tab=Trace) Reviewed By: JacobSzwejbka Differential Revision: D90395786
1 parent 88cfb1d commit 862df9c

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def call(self, graph_module: torch.fx.GraphModule):
9090
args = node.args
9191
meta = node.meta
9292
match len(args):
93+
case 6:
94+
# torch.ops.aten.layer_norm.default has 6 args:
95+
# (input, normalized_shape, weight, bias, eps, cudnn_enable)
96+
# cudnn_enable is not used in the decomposition
97+
x, normalized_shape, weights, bias, epsilon, _cudnn_enable = args
9398
case 5:
9499
x, normalized_shape, weights, bias, epsilon = args
95100
case 4:

backends/arm/test/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def define_arm_tests():
1919
"ops/test_avg_pool2d.py",
2020
"ops/test_cat.py",
2121
"ops/test_conv2d.py",
22-
"ops/test_linear.py",
22+
"ops/test_linear.py",
2323
"ops/test_mul.py",
2424
"ops/test_permute.py",
2525
"ops/test_rsqrt.py",

0 commit comments

Comments
 (0)