
Commit 7535720

Arm backend: Add GLU decomposition pass and test (#13270)
Decomposes the gated linear unit (GLU) function.

Signed-off-by: Teo Bergkvist <[email protected]>
1 parent 3067e98 commit 7535720
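For context: `torch.nn.functional.glu` splits its input into two equal halves along `dim` and multiplies the first half by the sigmoid of the second. A minimal eager-mode sketch of that identity, which the new pass reproduces with slice, sigmoid, and mul operators:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 6)
a, b = x.chunk(2, dim=-1)  # two equal halves along the last dimension
assert torch.allclose(F.glu(x, dim=-1), a * torch.sigmoid(b))
```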

5 files changed: +211, -0 lines changed


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -36,6 +36,7 @@
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa
+from .decompose_glu_pass import DecomposeGluPass  # noqa
 from .decompose_grouped_conv import DecomposeGroupedConv  # noqa
 from .decompose_groupnorm_pass import DecomposeGroupNormPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
```

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -41,6 +41,7 @@
     DecomposeDivPass,
     DecomposeEmbeddingPass,
     DecomposeGeluPass,
+    DecomposeGluPass,
     DecomposeGroupedConv,
     DecomposeGroupNormPass,
     DecomposeLayerNormPass,
@@ -184,6 +185,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(FuseBatchnorm2DPass(exported_program))
         self.add_pass(ConvertMmToBmmPass())
+        self.add_pass(DecomposeGluPass())
         self.add_pass(DecomposeLinearPass())
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeGroupNormPass())
@@ -264,6 +266,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
         self.add_pass(DecomposeNotEqualPass())
         self.add_pass(DecomposeCosineSimilarityPass())
+        self.add_pass(DecomposeGluPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeLinearVectorNormPass())
```
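The pass is registered twice because the two pipelines see different dialects: `_tosa_FP_pipeline` rewrites edge-dialect nodes, while `transform_for_annotation_pipeline` runs on aten-dialect graphs before quantization annotation. A hedged sketch (the module name `M` is illustrative) of why the aten-level match is needed: `glu` has its own kernel and so survives `torch.export` as a single `aten.glu.default` node rather than being decomposed during tracing.

```python
import torch
import torch.nn.functional as F

class M(torch.nn.Module):
    def forward(self, x):
        return F.glu(x, dim=-1)

ep = torch.export.export(M(), (torch.randn(4, 6),))
# The glu call survives export as a single aten node.
print(any(node.target == torch.ops.aten.glu.default for node in ep.graph.nodes))
```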
backends/arm/_passes/decompose_glu_pass.py

Lines changed: 75 additions & 0 deletions
```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops


# For FP case
edge_glu = exir_ops.edge.aten.glu.default

# For INT case
aten_glu = torch.ops.aten.glu.default


def get_ops(op):
    """Returns the appropriate operator functions based on the input operator."""
    if op == edge_glu:
        return (
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten.slice_copy.Tensor,
        )
    elif op == aten_glu:
        return (
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.sigmoid.default,
            torch.ops.aten.slice_copy.Tensor,
        )
    else:
        raise ValueError(f"Unsupported operator: {op}")


class DecomposeGluPass(ArmPass):
    """Decomposes the GLU operator into a Hadamard product and a sigmoid."""

    def call_operator(self, op, args, kwargs, meta):
        if op not in [edge_glu, aten_glu]:
            return super().call_operator(op, args, kwargs, meta)

        hadamard_prod, sigmoid, slice_op = get_ops(op)
        X = args[0]

        dim = args[1] if len(args) > 1 else kwargs.get("dim", -1)

        if "val" not in X.node.meta:
            raise Exception("Could not get dimension metadata in input.")

        if dim < 0:
            dim += X.node.meta["val"].dim()

        n = X.node.meta["val"].size(dim)

        if n % 2:
            raise RuntimeError(
                f"glu expects an even split along dim={dim}, got size {n}"
            )

        middle = n // 2

        T1 = super().call_operator(
            slice_op, (X, dim, 0, middle), {}, meta, updated=True
        )

        T2 = super().call_operator(
            slice_op, (X, dim, middle, n), {}, meta, updated=True
        )

        T2_sigmoid = super().call_operator(sigmoid, (T2,), {}, meta, updated=True)

        return super().call_operator(
            hadamard_prod, (T1, T2_sigmoid), {}, meta, updated=True
        )
```
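A quick eager-mode check of the rewrite the pass emits, using the same `slice_copy` overload that `get_ops` selects (a sketch for illustration, not part of the commit):

```python
import torch
import torch.nn.functional as F

x = torch.randn(10, 10, 2)
dim = -1 % x.dim()  # normalize the negative dim, as the pass does
n = x.size(dim)     # must be even, or the pass raises RuntimeError
middle = n // 2

t1 = torch.ops.aten.slice_copy.Tensor(x, dim, 0, middle)
t2 = torch.ops.aten.slice_copy.Tensor(x, dim, middle, n)
assert torch.allclose(t1 * torch.sigmoid(t2), F.glu(x, dim=dim))
```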

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -258,6 +258,7 @@ def is_node_supported(
             exir_ops.edge.aten.masked_fill.Scalar,
             exir_ops.edge.aten.asinh.default,
             exir_ops.edge.aten.cosh.default,
+            exir_ops.edge.aten.glu.default,
         ]

         return supported
@@ -299,6 +300,7 @@ def is_node_supported(
             exir_ops.edge.aten.leaky_relu.default: None,
             exir_ops.edge.aten.round.default: None,
             exir_ops.edge.aten.addmm.default: None,
+            exir_ops.edge.aten.glu.default: None,
         }

         if node.target in needs_decomp_dict:
```

backends/arm/test/ops/test_glu.py

Lines changed: 130 additions & 0 deletions
```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
import torch.nn.functional as F
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineINT,
    EthosU85PipelineINT,
    TosaPipelineFP,
    TosaPipelineINT,
    VgfPipeline,
)

aten_op = "torch.ops.aten.glu.default"
exir_op = "executorch_exir_dialects_edge__ops_aten__glu_default"


input_t1 = Tuple[torch.Tensor]

test_data_suite = {
    "zeros": [torch.zeros(10, 10, 2), -1],
    "ones": [torch.ones(10, 10, 2), -1],
    "rand": [torch.rand(10, 10, 2) - 0.5, -1],
    "randn_pos": [torch.randn(10, 2) + 10, -1],
    "randn_neg": [torch.randn(10, 2) - 10, -1],
    "ramp": [torch.linspace(-16, 15.8, 160).reshape(-1, 2), -1],
    "zeros_custom_dim": [torch.zeros(7, 10, 5), 1],
    "rand_custom_dim": [torch.rand(10, 3, 3) - 0.5, 0],
}


class Glu(torch.nn.Module):

    def forward(self, a: torch.Tensor, dim: int) -> torch.Tensor:
        return F.glu(a, dim=dim)


@common.parametrize(
    "test_data",
    test_data_suite,
)
def test_glu_tosa_FP(test_data: Tuple):
    pipeline = TosaPipelineFP[input_t1](
        Glu(),
        (*test_data,),
        aten_op,
        exir_op,
    )
    pipeline.run()


@common.parametrize(
    "test_data",
    test_data_suite,
)
def test_glu_tosa_INT(test_data: Tuple):
    pipeline = TosaPipelineINT[input_t1](
        Glu(),
        (*test_data,),
        aten_op=[],
        exir_op=exir_op,
    )
    pipeline.run()


@common.parametrize(
    "test_data",
    test_data_suite,
)
@common.XfailIfNoCorstone300
def test_glu_u55_INT(test_data: Tuple):
    pipeline = EthosU55PipelineINT[input_t1](
        Glu(),
        (*test_data,),
        aten_ops=[],
        exir_ops=exir_op,
    )
    pipeline.run()


@common.parametrize(
    "test_data",
    test_data_suite,
)
@common.XfailIfNoCorstone320
def test_glu_u85_INT(test_data: Tuple):
    pipeline = EthosU85PipelineINT[input_t1](
        Glu(),
        (*test_data,),
        aten_ops=[],
        exir_ops=exir_op,
    )
    pipeline.run()


@common.parametrize(
    "test_data",
    test_data_suite,
)
@common.SkipIfNoModelConverter
def test_glu_vgf_FP(test_data: input_t1):
    pipeline = VgfPipeline[input_t1](
        Glu(),
        (*test_data,),
        [],
        [],
        tosa_version="TOSA-1.0+FP",
    )
    pipeline.run()


@common.parametrize(
    "test_data",
    test_data_suite,
)
@common.SkipIfNoModelConverter
def test_glu_vgf_INT(test_data: input_t1):
    pipeline = VgfPipeline[input_t1](
        Glu(),
        (*test_data,),
        [],
        [],
        tosa_version="TOSA-1.0+INT",
    )
    pipeline.run()
```
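Each suite entry is an `(input, dim)` pair that the tests splat into the pipeline as the arguments of `Glu.forward`; the output halves the gated dimension. A quick eager check of the `"rand_custom_dim"` entry, outside the pipelines:

```python
import torch
import torch.nn.functional as F

x, dim = torch.rand(10, 3, 3) - 0.5, 0  # the "rand_custom_dim" entry
print(F.glu(x, dim=dim).shape)          # torch.Size([5, 3, 3])
```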
