Skip to content

Commit cc93676

Browse files
Andrew Pullinfacebook-github-bot
authored andcommitted
Fix DecomposeLayerNormPass to handle 6-arg layer_norm
Summary: ## Problem When using `nn.LayerNorm` in models that go through modai_sdk's `post_train_quantize` flow, the `DecomposeLayerNormPass` fails with: ``` ValueError: DecomposeLayerNormPass: too many values to unpack (expected 2) ``` This happens because `torch.ops.aten.layer_norm.default` has **6 arguments**: ``` layer_norm(input, normalized_shape, weight, bias, eps, cudnn_enable) ``` But `DecomposeLayerNormPass` only handled up to 5 arguments (for `native_layer_norm`). The error occurs during `transform_for_annotation_pipeline` in the ARM quantizer, which runs before edge transformation when the op is still `aten.layer_norm.default`. ## Solution Add `case 6:` to the `match len(args)` block in `DecomposeLayerNormPass.call()` to handle the 6th argument (`cudnn_enable`). This argument is simply ignored during decomposition since it's only relevant for cuDNN GPU optimization. ## Testing Added a new test file `test_layernorm_modai_compat.py` that: 1. Creates a simple Linear -> LayerNorm -> Linear model 2. Exports it via `torch.export` 3. Runs it through `transform_for_annotation_pipeline` (the exact path that was failing) 4. Verifies LayerNorm is decomposed correctly through the full TOSA pipelines --- > Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Confucius Session](https://www.internalfb.com/confucius?host=92481.od.fbinfra.net&port=8086&tab=Chat&session_id=eace3d92-ed78-11f0-b67c-c7843469b0d5&entry_name=Code+Assist), [Trace](https://www.internalfb.com/confucius?session_id=eace3d92-ed78-11f0-b67c-c7843469b0d5&tab=Trace) Differential Revision: D90395786
1 parent 5b4900c commit cc93676

File tree

3 files changed

+180
-1
lines changed

3 files changed

+180
-1
lines changed

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def call(self, graph_module: torch.fx.GraphModule):
9090
args = node.args
9191
meta = node.meta
9292
match len(args):
93+
case 6:
94+
# torch.ops.aten.layer_norm.default has 6 args:
95+
# (input, normalized_shape, weight, bias, eps, cudnn_enable)
96+
# cudnn_enable is not used in the decomposition
97+
x, normalized_shape, weights, bias, epsilon, _cudnn_enable = args
9398
case 5:
9499
x, normalized_shape, weights, bias, epsilon = args
95100
case 4:
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""
7+
Test LayerNorm compatibility with modai_sdk lowering pipeline.
8+
9+
This test verifies that nn.LayerNorm works correctly through the ARM backend's
10+
transform_for_annotation pipeline, which is used by modai_sdk during PTQ.
11+
12+
The key issue was that torch.ops.aten.layer_norm.default has 6 arguments:
13+
(input, normalized_shape, weight, bias, eps, cudnn_enable)
14+
15+
But DecomposeLayerNormPass only handled up to 5 args, causing a ValueError
16+
when the 6th arg (cudnn_enable) was present.
17+
18+
Related: D88489694, T247846380
19+
"""
20+
21+
from typing import Tuple
22+
23+
import torch
24+
import torch.nn as nn
25+
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
26+
from executorch.backends.arm.test import common
27+
from executorch.backends.arm.test.tester.test_pipeline import (
28+
EthosU55PipelineINT,
29+
TosaPipelineFP,
30+
TosaPipelineINT,
31+
)
32+
from executorch.backends.arm.tosa.specification import TosaSpecification
33+
34+
input_t = Tuple[torch.Tensor]
35+
36+
37+
class SimpleLayerNormModel(nn.Module):
38+
"""Simple model: Linear -> LayerNorm -> Linear"""
39+
40+
def __init__(self, hidden_dim: int = 32):
41+
super().__init__()
42+
self.linear1 = nn.Linear(hidden_dim, hidden_dim)
43+
self.layer_norm = nn.LayerNorm(hidden_dim)
44+
self.linear2 = nn.Linear(hidden_dim, hidden_dim)
45+
46+
def forward(self, x: torch.Tensor) -> torch.Tensor:
47+
x = self.linear1(x)
48+
x = self.layer_norm(x)
49+
x = self.linear2(x)
50+
return x
51+
52+
53+
class LayerNormWithPermute(nn.Module):
54+
"""
55+
LayerNorm with permute pattern (channels first).
56+
57+
This pattern is common in models like HRNet where the data is in
58+
(B, C, H, L) format and needs to be permuted for LayerNorm.
59+
"""
60+
61+
def __init__(self, num_channels: int = 16):
62+
super().__init__()
63+
self.layer_norm = nn.LayerNorm(num_channels)
64+
65+
def forward(self, x: torch.Tensor) -> torch.Tensor:
66+
# x is (B, C, H, L) - channels first
67+
x = torch.permute(x, (0, 2, 3, 1)) # (B, C, H, L) -> (B, H, L, C)
68+
x = self.layer_norm(x)
69+
x = torch.permute(x, (0, 3, 1, 2)) # (B, H, L, C) -> (B, C, H, L)
70+
return x
71+
72+
73+
test_data_suite = {
74+
"simple_1d": lambda: (
75+
(torch.randn(1, 32),),
76+
SimpleLayerNormModel(hidden_dim=32),
77+
),
78+
"simple_2d": lambda: (
79+
(torch.randn(1, 10, 32),),
80+
SimpleLayerNormModel(hidden_dim=32),
81+
),
82+
"with_permute": lambda: (
83+
(torch.randn(1, 16, 4, 8),),
84+
LayerNormWithPermute(num_channels=16),
85+
),
86+
}
87+
88+
89+
def test_layernorm_transform_for_annotation():
90+
"""
91+
Test that LayerNorm works through transform_for_annotation pipeline.
92+
93+
This is the key test - it directly tests the pipeline that was failing
94+
in modai_sdk when DecomposeLayerNormPass couldn't handle 6 args.
95+
"""
96+
model = SimpleLayerNormModel(hidden_dim=16).eval()
97+
sample_input = (torch.randn(1, 16),)
98+
99+
# Export the model
100+
exported_program = torch.export.export(model, sample_input)
101+
graph_module = exported_program.graph_module
102+
103+
# Debug: Print out what layer_norm nodes look like
104+
print("\n=== Exported graph nodes ===")
105+
for node in graph_module.graph.nodes:
106+
if "layer_norm" in str(node.target):
107+
print(f"Node: {node.name}")
108+
print(f" Target: {node.target}")
109+
print(f" Args count: {len(node.args)}")
110+
print(f" Args: {node.args}")
111+
print(f" Kwargs: {node.kwargs}")
112+
print("=== End of layer_norm nodes ===\n")
113+
114+
# Create ArmPassManager with proper compile spec (similar to what modai_sdk does)
115+
# ArmPassManager expects an ArmCompileSpec, not TosaSpecification directly
116+
from executorch.backends.arm.test import common as ethos_common
117+
118+
compile_spec = ethos_common.get_tosa_compile_spec("TOSA-1.00+INT+FP")
119+
pass_manager = ArmPassManager(compile_spec)
120+
121+
# This is the call that was failing before the fix
122+
# It runs DecomposeLayerNormPass among other passes
123+
try:
124+
result = pass_manager.transform_for_annotation_pipeline(
125+
graph_module=graph_module
126+
)
127+
assert result is not None, "transform_for_annotation_pipeline returned None"
128+
except ValueError as e:
129+
if "too many values to unpack" in str(e):
130+
raise AssertionError(
131+
f"DecomposeLayerNormPass failed to handle layer_norm args: {e}"
132+
) from e
133+
raise
134+
135+
136+
@common.parametrize("test_data", test_data_suite)
137+
def test_layernorm_tosa_FP(test_data):
138+
"""Test LayerNorm in TOSA FP pipeline."""
139+
test_data, model = test_data()
140+
pipeline = TosaPipelineFP[input_t](
141+
model,
142+
test_data,
143+
"torch.ops.aten.layer_norm.default",
144+
)
145+
pipeline.run()
146+
147+
148+
@common.parametrize("test_data", test_data_suite)
149+
def test_layernorm_tosa_INT(test_data):
150+
"""Test LayerNorm in TOSA INT (quantized) pipeline."""
151+
test_data, model = test_data()
152+
pipeline = TosaPipelineINT[input_t](
153+
model,
154+
test_data,
155+
# After decomposition, check for sub op which is part of layernorm
156+
"torch.ops.aten.sub.Tensor",
157+
symmetric_io_quantization=True,
158+
)
159+
pipeline.run()
160+
161+
162+
@common.parametrize("test_data", test_data_suite)
163+
@common.XfailIfNoCorstone300
164+
def test_layernorm_u55_INT(test_data):
165+
"""Test LayerNorm in Ethos U55 INT pipeline."""
166+
test_data, model = test_data()
167+
pipeline = EthosU55PipelineINT[input_t](
168+
model,
169+
test_data,
170+
"torch.ops.aten.sub.Tensor",
171+
symmetric_io_quantization=True,
172+
)
173+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def define_arm_tests():
1919
"ops/test_avg_pool2d.py",
2020
"ops/test_cat.py",
2121
"ops/test_conv2d.py",
22-
"ops/test_linear.py",
22+
"ops/test_layernorm_modai_compat.py",
23+
"ops/test_linear.py",
2324
"ops/test_mul.py",
2425
"ops/test_permute.py",
2526
"ops/test_rsqrt.py",

0 commit comments

Comments
 (0)