
Commit 2ae39a0

Add layernorm decomposition
- Decompose layernorm
- Add unittest for layernorm

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Iccc382898cf247c560ef55c4711fab40d47f04dc
1 parent dcf549f commit 2ae39a0

7 files changed: +331 −11 lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 0 deletions

@@ -19,6 +19,9 @@
     ConvertSplitToSlicePass,
 )
 from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
+from executorch.backends.arm._passes.decompose_layernorm_pass import (
+    DecomposeLayerNormPass,
+)
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
 from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
@@ -50,6 +53,7 @@ def transform_to_backend_pipeline(
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
+        self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(ConvertMeanDimToAveragePool())
         self.add_pass(DecomposeMeanDimPass())
@@ -65,6 +69,7 @@ def transform_to_backend_pipeline(
         return self._transform(exported_program.graph_module)

     def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
+        self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
         self.add_pass(DecomposeDivPass())
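Reviewer note on pass ordering: DecomposeLayerNormPass is inserted ahead of DecomposeVarPass and DecomposeMeanDimPass in both pipelines because the layernorm decomposition itself emits aten.var.correction and aten.mean.dim nodes, which those later passes lower further. A minimal sketch of a model this path applies to (the toy module below is an illustrative assumption, not part of the commit):

    # Illustrative toy module: its exported graph contains a layer_norm node,
    # which DecomposeLayerNormPass rewrites into mean/var/sub/add/rsqrt/mul
    # (plus view_copy) nodes before the var and mean passes run.
    import torch


    class ToyLayerNormModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.norm = torch.nn.LayerNorm(16)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.norm(x)


    model = ToyLayerNormModel().eval()
    example_input = (torch.randn(1, 8, 16),)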
backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 146 additions & 0 deletions

@@ -0,0 +1,146 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import operator
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+def get_layer_norm_decomposition(op) -> tuple:
+    if op == exir_ops.edge.aten.native_layer_norm.default:
+        return (
+            exir_ops.edge.aten.mean.dim,
+            exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.var.correction,
+            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.rsqrt.default,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.view_copy.default,
+        )
+    if op == torch.ops.aten.layer_norm.default:
+        return (
+            torch.ops.aten.mean.dim,
+            torch.ops.aten.sub.Tensor,
+            torch.ops.aten.var.correction,
+            torch.ops.aten.full.default,
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.rsqrt.default,
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.view_copy.default,
+        )
+    raise RuntimeError(f"Can't get layer_norm decomposition for op {op}")
+
+
+class DecomposeLayerNormPass(ExportPass):
+    """
+    layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
+    Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
+    mean      = op_mean(x, dims)         # E[x]
+    var       = op_var(x, dims)          # Var[x]
+    numerator = op_sub(x, mean)          # (x - E[x])
+    add       = op_add(var, eps)         # Var[x] + eps
+    rsqrt     = op_rsqrt(add)            # 1 / sqrt(Var[x] + eps)
+    mul       = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) * weights
+    bias      = op_add(mul, bias)        # ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
+
+    Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
+    """
+
+    def call(self, gm: torch.fx.GraphModule):
+        for node in gm.graph.nodes:
+            if node.op != "call_function" or node.target not in (
+                exir_ops.edge.aten.native_layer_norm.default,
+                torch.ops.aten.layer_norm.default,
+            ):
+                continue
+
+            # epsilon default value
+            epsilon = 1e-5
+            weights = None
+            bias = None
+            args = node.args
+            meta = node.meta
+            match len(args):
+                case 5:
+                    x, normalized_shape, weights, bias, epsilon = args
+                case 4:
+                    x, normalized_shape, weights, bias = args
+                case 3:
+                    x, normalized_shape, weights = args
+                case 2:
+                    x, normalized_shape = args
+
+            n_dims = len(normalized_shape)
+            if isinstance(meta["val"], tuple):
+                shape = meta["val"][0].size()
+            else:
+                shape = meta["val"].size()
+            dtype = meta["val"][0].dtype
+            rank = len(shape)
+            dims = list(range(-1, -1 * (n_dims + 1), -1))
+            dims = [dim % rank for dim in dims]
+            weights_reshaped_shape = [shape[i] if i in dims else 1 for i in range(rank)]
+            epsilon_reshaped_shape = [1] * rank
+
+            (
+                mean_op,
+                sub_op,
+                var_op,
+                full_op,
+                add_op,
+                rsqrt_op,
+                mul_op,
+                view_op,
+            ) = get_layer_norm_decomposition(node.target)
+            with gm.graph.inserting_before(node):
+                keepdim = True
+                mean = create_node(gm.graph, mean_op, args=(x, dims, keepdim))
+                sub = create_node(gm.graph, sub_op, args=(x, mean))
+                var = create_node(
+                    gm.graph,
+                    var_op,
+                    args=(x, dims),
+                    kwargs={"correction": 0, "keepdim": keepdim},
+                )
+                full = create_node(
+                    gm.graph,
+                    full_op,
+                    args=(epsilon_reshaped_shape, epsilon),
+                    kwargs={"dtype": dtype},
+                )
+                add0 = create_node(gm.graph, add_op, args=(var, full))
+                rsqrt = create_node(gm.graph, rsqrt_op, args=(add0,))
+                mul0 = create_node(gm.graph, mul_op, args=(sub, rsqrt))
+                if weights is not None:
+                    weights_reshaped = create_node(
+                        gm.graph, view_op, args=(weights, weights_reshaped_shape)
+                    )
+                    mul1 = create_node(gm.graph, mul_op, args=(mul0, weights_reshaped))
+                else:
+                    mul1 = mul0
+                output = mul1
+                if bias is not None:
+                    bias_reshaped_shape = weights_reshaped_shape
+                    bias_reshaped = create_node(
+                        gm.graph, view_op, args=(bias, bias_reshaped_shape)
+                    )
+                    output = create_node(gm.graph, add_op, args=(mul1, bias_reshaped))
+
+            users = [user for user in node.users if node != user]
+            node.replace_all_uses_with(output)
+            for user in users:
+                if user.target == operator.getitem:
+                    user.replace_all_uses_with(output)
+            gm.graph.erase_node(node)
+        gm.graph.eliminate_dead_code()
+        gm.recompile()
+        gm = super().call(gm).graph_module
+
+        return PassResult(gm, True)
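As a standalone sanity check (not part of the commit) that the decomposed sequence in the docstring reproduces layernorm, the same mean/var/rsqrt arithmetic can be compared against torch.nn.functional.layer_norm; the shapes and eps below are arbitrary:

    # Check that the decomposition matches PyTorch's layer_norm numerically.
    import torch

    x = torch.randn(2, 3, 16)
    normalized_shape = [16]
    weights = torch.randn(16)
    bias = torch.randn(16)
    epsilon = 1e-5

    dims = [-1]  # the last len(normalized_shape) dimensions
    mean = x.mean(dim=dims, keepdim=True)              # E[x]
    var = x.var(dim=dims, correction=0, keepdim=True)  # Var[x], correction=0 as in the pass
    decomposed = (x - mean) * torch.rsqrt(var + epsilon) * weights + bias

    reference = torch.nn.functional.layer_norm(x, normalized_shape, weights, bias, epsilon)
    torch.testing.assert_close(decomposed, reference)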

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.full.default,
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+            exir_ops.edge.aten.native_layer_norm.default,
             exir_ops.edge.aten.avg_pool2d.default,
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mm.default,

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 2 additions & 0 deletions

@@ -149,11 +149,13 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
         torch.ops.aten.squeeze.default,
         torch.ops.aten.squeeze_copy.dim,
         torch.ops.aten.unsqueeze.default,
+        torch.ops.aten.unsqueeze_copy.default,
         # TODO: remove?
         torch.ops.aten.adaptive_avg_pool2d.default,
         torch.ops.aten.avg_pool2d.default,
         torch.ops.aten.view_copy.default,
         torch.ops.aten.view.default,
+        torch.ops.aten.full.default,
         torch.ops.aten.slice.Tensor,
         torch.ops.aten.split.Tensor,
         torch.ops.aten.split_with_sizes.default,

backends/arm/quantizer/quantization_annotation/add_annotator.py

Lines changed: 5 additions & 10 deletions

@@ -6,8 +6,6 @@

 # pyre-unsafe

-import itertools
-import operator
 from typing import Callable, List, Optional

 import torch
@@ -16,7 +14,6 @@
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import Node
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


 @register_annotator("add")
@@ -25,14 +22,12 @@ def _annotate_add(
     quantization_config: QuantizationConfig,
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
-    add_partitions = get_source_partitions(
-        gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
-    )
-    add_partitions = list(itertools.chain.from_iterable(add_partitions.values()))
     annotated_partitions = []
-    for add_partition in add_partitions:
-        annotated_partitions.append(add_partition.nodes)
-        add_node = add_partition.output_nodes[0]
+    for node in gm.graph.nodes:
+        if node.target not in (torch.ops.aten.add.Tensor,):
+            continue
+        annotated_partitions.append(node)
+        add_node = node
         if arm_quantizer_utils.is_annotated(add_node):
             continue

