
Commit 3816878

Merge branch 'main' into Add-support-for-expm1
2 parents ffe8925 + 7535720

File tree

6 files changed (+254 lines, -44 lines)


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa  # noqa
 from .decompose_expm1_pass import DecomposeExpm1Pass  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa
+from .decompose_glu_pass import DecomposeGluPass  # noqa
 from .decompose_grouped_conv import DecomposeGroupedConv  # noqa
 from .decompose_groupnorm_pass import DecomposeGroupNormPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,7 @@
     DecomposeEmbeddingPass,
     DecomposeExpm1Pass,
     DecomposeGeluPass,
+    DecomposeGluPass,
     DecomposeGroupedConv,
     DecomposeGroupNormPass,
     DecomposeLayerNormPass,
@@ -186,6 +187,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(FuseBatchnorm2DPass(exported_program))
         self.add_pass(ConvertMmToBmmPass())
+        self.add_pass(DecomposeGluPass())
         self.add_pass(DecomposeLinearPass())
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeGroupNormPass())
@@ -266,6 +268,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
         self.add_pass(DecomposeNotEqualPass())
         self.add_pass(DecomposeCosineSimilarityPass())
+        self.add_pass(DecomposeGluPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeLinearVectorNormPass())
backends/arm/_passes/decompose_glu_pass.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+# For FP case
+edge_glu = exir_ops.edge.aten.glu.default
+
+# For INT case
+aten_glu = torch.ops.aten.glu.default
+
+
+def get_ops(op):
+    """Returns the appropriate operator functions based on the input operator."""
+    if op == edge_glu:
+        return (
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.sigmoid.default,
+            exir_ops.edge.aten.slice_copy.Tensor,
+        )
+    elif op == aten_glu:
+        return (
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.sigmoid.default,
+            torch.ops.aten.slice_copy.Tensor,
+        )
+    else:
+        raise ValueError(f"Unsupported operator: {op}")
+
+
+class DecomposeGluPass(ArmPass):
+    """Decomposes the GLU operator into hadamard product and sigmoid."""
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in [edge_glu, aten_glu]:
+            return super().call_operator(op, args, kwargs, meta)
+
+        hadamard_prod, sigmoid, slice_op = get_ops(op)
+        X = args[0]
+
+        dim = args[1] if len(args) > 1 else kwargs.get("dim", -1)
+
+        if "val" not in X.node.meta:
+            raise Exception("Could not get dimension metadata in input.")
+
+        if dim < 0:
+            dim += X.node.meta["val"].dim()
+
+        n = X.node.meta["val"].size(dim)
+
+        if n % 2:
+            raise RuntimeError(
+                f"glu expects an even split along dim={dim}, got size {n}"
+            )
+
+        middle = n // 2
+
+        T1 = super().call_operator(
+            slice_op, (X, dim, 0, middle), {}, meta, updated=True
+        )
+
+        T2 = super().call_operator(
+            slice_op, (X, dim, middle, n), {}, meta, updated=True
+        )
+
+        T2_sigmoid = super().call_operator(sigmoid, (T2,), {}, meta, updated=True)
+
+        return super().call_operator(
+            hadamard_prod, (T1, T2_sigmoid), {}, meta, updated=True
+        )
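The decomposition mirrors the definition of GLU: the input is split in half along dim, and the first half is multiplied elementwise with the sigmoid of the second half. As a minimal sketch in plain PyTorch, independent of the pass infrastructure (the helper name glu_decomposed is made up for illustration), the equivalence can be checked against F.glu:

import torch
import torch.nn.functional as F


def glu_decomposed(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Same logic as the pass: split in half along `dim`,
    # then first_half * sigmoid(second_half).
    n = x.size(dim)
    assert n % 2 == 0, f"glu expects an even split along dim={dim}, got size {n}"
    middle = n // 2
    t1 = x.narrow(dim, 0, middle)
    t2 = x.narrow(dim, middle, n - middle)
    return t1 * torch.sigmoid(t2)


x = torch.randn(4, 6)
assert torch.allclose(glu_decomposed(x, dim=-1), F.glu(x, dim=-1))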

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 43 additions & 44 deletions
@@ -3,6 +3,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import hashlib
+from collections import defaultdict
+
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_constant_placeholder_kind,
@@ -21,7 +24,7 @@ class FuseEqualPlaceholdersPass(ExportPass):
     """
     This pass optimizes memory usage by finding constant placeholders
     pointing to identical tensors and fusing them to one single placeholder
-    with multiple users.
+    with multiple users, using a cache for faster comparison.
     """
 
     def __init__(self, exported_program: ExportedProgram):
@@ -30,58 +33,54 @@ def __init__(self, exported_program: ExportedProgram):
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         modified = False
-        const_placeholder_nodes = []
-        for node in graph_module.graph.nodes:
-            if is_param_node(self.exported_program, node):
-                const_placeholder_nodes.append(node)
-
-        while const_placeholder_nodes:
 
-            # Find equal tensors
-            node1 = const_placeholder_nodes.pop()
-            eq_nodes = [node1]
-            tensor1 = get_param_tensor(self.exported_program, node1)
-            if tensor1 is None:
+        # Build a cache of params: mapping hash_key -> list of (node, tensor)
+        hash_buckets = defaultdict(list)
+        for node in graph_module.graph.nodes:
+            if not is_param_node(self.exported_program, node):
                 continue
+            tensor = get_param_tensor(self.exported_program, node)
+            if tensor is None:
+                continue
+            # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
+            # Ensure tensor is on CPU and contiguous
+            t_cpu = tensor.detach().cpu().contiguous()
+            data_bytes = t_cpu.numpy().tobytes()
+            key = (
+                str(t_cpu.dtype),
+                tuple(t_cpu.shape),
+                hashlib.sha1(data_bytes).hexdigest(),
+            )
+            hash_buckets[key].append((node, t_cpu))
 
-            for node2 in const_placeholder_nodes:
-                tensor2 = get_param_tensor(self.exported_program, node2)
-                if tensor2 is None:
-                    continue
-
-                if (
-                    tensor1.dtype == tensor2.dtype
-                    and tensor1.shape == tensor2.shape
-                    and torch.allclose(tensor1, tensor2, atol=1e-08)
-                ):
-                    eq_nodes.append(node2)
+        # For each bucket with more than one entry, fuse:
+        for nodes_tensors in hash_buckets.values():
+            if len(nodes_tensors) < 2:
+                continue
 
-            if len(eq_nodes) > 1:
-                common_name = node1.name + "_common"
-                common_kind = get_constant_placeholder_kind(
-                    self.exported_program, node1
+            # Create a new placeholder from first in list of equal placeholders.
+            rep_node, rep_tensor = nodes_tensors[0]
+            common_name = rep_node.name + "_common"
+            common_kind = get_constant_placeholder_kind(self.exported_program, rep_node)
+            common_persistent = True
+            with graph_module.graph.inserting_before(rep_node):
+                common_node = create_constant_placeholder(
+                    self.exported_program,
+                    graph_module.graph,
+                    common_name,
+                    common_kind,
+                    rep_tensor,
+                    common_persistent,
                 )
-                common_persisten_buffer = True
-
-                with graph_module.graph.inserting_before(node1):
-                    common_node = create_constant_placeholder(
-                        self.exported_program,
-                        graph_module.graph,
-                        common_name,
-                        common_kind,
-                        tensor1,
-                        common_persisten_buffer,
-                    )
-
-                for eq_node in eq_nodes:
-                    eq_node.replace_all_uses_with(common_node)
-                    delete_constant_placeholder(self.exported_program, eq_node)
-                    if eq_node != node1:
-                        const_placeholder_nodes.remove(eq_node)
 
+            # Replace uses and delete duplicates
+            for node, _ in nodes_tensors:
+                node.replace_all_uses_with(common_node)
+                delete_constant_placeholder(self.exported_program, node)
             modified = True
 
         if modified:
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
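The rewrite replaces the old pairwise search, which ran torch.allclose over every pair of constant placeholders (quadratic in the number of constants), with a single pass that buckets tensors by a fingerprint of dtype, shape, and a SHA-1 digest of the raw bytes. Only byte-identical tensors land in the same bucket, so the previous atol=1e-08 tolerance is effectively tightened to exact equality. A minimal sketch of the bucketing idea on plain named tensors, outside the pass infrastructure (bucket_by_fingerprint is a hypothetical helper, not part of the pass):

import hashlib
from collections import defaultdict

import torch


def bucket_by_fingerprint(tensors: dict) -> dict:
    # Group tensors whose dtype, shape and raw bytes are identical.
    buckets = defaultdict(list)
    for name, t in tensors.items():
        t_cpu = t.detach().cpu().contiguous()
        key = (
            str(t_cpu.dtype),
            tuple(t_cpu.shape),
            hashlib.sha1(t_cpu.numpy().tobytes()).hexdigest(),
        )
        buckets[key].append(name)
    return buckets


weights = {
    "w0": torch.ones(2, 2),
    "w1": torch.ones(2, 2),   # byte-identical to w0 -> same bucket
    "w2": torch.zeros(2, 2),
}
duplicates = [names for names in bucket_by_fingerprint(weights).values() if len(names) > 1]
print(duplicates)  # [['w0', 'w1']]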

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -259,6 +259,7 @@ def is_node_supported(
             exir_ops.edge.aten.masked_fill.Scalar,
             exir_ops.edge.aten.asinh.default,
             exir_ops.edge.aten.cosh.default,
+            exir_ops.edge.aten.glu.default,
         ]
 
         return supported
@@ -300,6 +301,7 @@ def is_node_supported(
             exir_ops.edge.aten.leaky_relu.default: None,
             exir_ops.edge.aten.round.default: None,
             exir_ops.edge.aten.addmm.default: None,
+            exir_ops.edge.aten.glu.default: None,
         }
 
         if node.target in needs_decomp_dict:

backends/arm/test/ops/test_glu.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+aten_op = "torch.ops.aten.glu.default"
+exir_op = "executorch_exir_dialects_edge__ops_aten__glu_default"
+
+
+input_t1 = Tuple[torch.Tensor]
+
+test_data_suite = {
+    "zeros": [torch.zeros(10, 10, 2), -1],
+    "ones": [torch.ones(10, 10, 2), -1],
+    "rand": [torch.rand(10, 10, 2) - 0.5, -1],
+    "randn_pos": [torch.randn(10, 2) + 10, -1],
+    "randn_neg": [torch.randn(10, 2) - 10, -1],
+    "ramp": [torch.linspace(-16, 15.8, 160).reshape(-1, 2), -1],
+    "zeros_custom_dim": [torch.zeros(7, 10, 5), 1],
+    "rand_custom_dim": [torch.rand(10, 3, 3) - 0.5, 0],
+}
+
+
+class Glu(torch.nn.Module):
+
+    def forward(self, a: torch.Tensor, dim: int) -> torch.Tensor:
+        return F.glu(a, dim=dim)
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+)
+def test_glu_tosa_FP(test_data: Tuple):
+    pipeline = TosaPipelineFP[input_t1](
+        Glu(),
+        (*test_data,),
+        aten_op,
+        exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+)
+def test_glu_tosa_INT(test_data: Tuple):
+    pipeline = TosaPipelineINT[input_t1](
+        Glu(),
+        (*test_data,),
+        aten_op=[],
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+)
+@common.XfailIfNoCorstone300
+def test_glu_u55_INT(test_data: Tuple):
+    pipeline = EthosU55PipelineINT[input_t1](
+        Glu(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+)
+@common.XfailIfNoCorstone320
+def test_glu_u85_INT(test_data: Tuple):
+    pipeline = EthosU85PipelineINT[input_t1](
+        Glu(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+)
+@common.SkipIfNoModelConverter
+def test_glu_vgf_FP(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        Glu(),
+        (*test_data,),
+        [],
+        [],
+        tosa_version="TOSA-1.0+FP",
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+)
+@common.SkipIfNoModelConverter
+def test_glu_vgf_INT(test_data: input_t1):
+    pipeline = VgfPipeline[input_t1](
+        Glu(),
+        (*test_data,),
+        [],
+        [],
+        tosa_version="TOSA-1.0+INT",
+    )
+    pipeline.run()
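Assuming a source checkout of executorch with the Arm test dependencies installed, the new tests can be run selectively with pytest; the Ethos-U55/U85 and VGF variants are gated by the XfailIfNoCorstone*/SkipIfNoModelConverter decorators above, so they are expected to fail or be skipped when the Corstone FVPs or the model converter are not available:

python -m pytest backends/arm/test/ops/test_glu.py -k "tosa_FP or tosa_INT" -v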
