Skip to content

Arm backend: Add GLU decomposition pass and test #13270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
from .decompose_gelu_pass import DecomposeGeluPass # noqa
from .decompose_glu_pass import DecomposeGluPass # noqa
from .decompose_grouped_conv import DecomposeGroupedConv # noqa
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
DecomposeDivPass,
DecomposeEmbeddingPass,
DecomposeGeluPass,
DecomposeGluPass,
DecomposeGroupedConv,
DecomposeGroupNormPass,
DecomposeLayerNormPass,
Expand Down Expand Up @@ -184,6 +185,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(FuseBatchnorm2DPass(exported_program))
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeGluPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(DecomposeLeakyReLUPass())
self.add_pass(DecomposeGroupNormPass())
Expand Down Expand Up @@ -264,6 +266,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
self.add_pass(DecomposeNotEqualPass())
self.add_pass(DecomposeCosineSimilarityPass())
self.add_pass(DecomposeGluPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeLeakyReLUPass())
self.add_pass(DecomposeLinearVectorNormPass())
Expand Down
75 changes: 75 additions & 0 deletions backends/arm/_passes/decompose_glu_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops


# For FP case
edge_glu = exir_ops.edge.aten.glu.default

# For INT case
aten_glu = torch.ops.aten.glu.default


def get_ops(op):
"""Returns the appropriate operator functions based on the input operator."""
if op == edge_glu:
return (
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.slice_copy.Tensor,
)
elif op == aten_glu:
return (
torch.ops.aten.mul.Tensor,
torch.ops.aten.sigmoid.default,
torch.ops.aten.slice_copy.Tensor,
)
else:
raise ValueError(f"Unsupported operator: {op}")


class DecomposeGluPass(ArmPass):
"""Decomposes the GLU operator into hadamard product and sigmoid."""

def call_operator(self, op, args, kwargs, meta):
if op not in [edge_glu, aten_glu]:
return super().call_operator(op, args, kwargs, meta)

hadamard_prod, sigmoid, slice_op = get_ops(op)
X = args[0]

dim = args[1] if len(args) > 1 else kwargs.get("dim", -1)

if "val" not in X.node.meta:
raise Exception("Could not get dimension metadata in input.")

if dim < 0:
dim += X.node.meta["val"].dim()

n = X.node.meta["val"].size(dim)

if n % 2:
raise RuntimeError(
f"glu expects an even split along dim={dim}, got size {n}"
)

middle = n // 2

T1 = super().call_operator(
slice_op, (X, dim, 0, middle), {}, meta, updated=True
)

T2 = super().call_operator(
slice_op, (X, dim, middle, n), {}, meta, updated=True
)

T2_sigmoid = super().call_operator(sigmoid, (T2,), {}, meta, updated=True)

return super().call_operator(
hadamard_prod, (T1, T2_sigmoid), {}, meta, updated=True
)
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def is_node_supported(
exir_ops.edge.aten.masked_fill.Scalar,
exir_ops.edge.aten.asinh.default,
exir_ops.edge.aten.cosh.default,
exir_ops.edge.aten.glu.default,
]

return supported
Expand Down Expand Up @@ -299,6 +300,7 @@ def is_node_supported(
exir_ops.edge.aten.leaky_relu.default: None,
exir_ops.edge.aten.round.default: None,
exir_ops.edge.aten.addmm.default: None,
exir_ops.edge.aten.glu.default: None,
}

if node.target in needs_decomp_dict:
Expand Down
130 changes: 130 additions & 0 deletions backends/arm/test/ops/test_glu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
import torch.nn.functional as F
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)

aten_op = "torch.ops.aten.glu.default"
exir_op = "executorch_exir_dialects_edge__ops_aten__glu_default"


input_t1 = Tuple[torch.Tensor]

test_data_suite = {
"zeros": [torch.zeros(10, 10, 2), -1],
"ones": [torch.ones(10, 10, 2), -1],
"rand": [torch.rand(10, 10, 2) - 0.5, -1],
"randn_pos": [torch.randn(10, 2) + 10, -1],
"randn_neg": [torch.randn(10, 2) - 10, -1],
"ramp": [torch.linspace(-16, 15.8, 160).reshape(-1, 2), -1],
"zeros_custom_dim": [torch.zeros(7, 10, 5), 1],
"rand_custom_dim": [torch.rand(10, 3, 3) - 0.5, 0],
}


class Glu(torch.nn.Module):

def forward(self, a: torch.Tensor, dim: int) -> torch.Tensor:
return F.glu(a, dim=dim)


@common.parametrize(
"test_data",
test_data_suite,
)
def test_glu_tosa_FP(test_data: Tuple):
pipeline = TosaPipelineFP[input_t1](
Glu(),
(*test_data,),
aten_op,
exir_op,
)
pipeline.run()


@common.parametrize(
"test_data",
test_data_suite,
)
def test_glu_tosa_INT(test_data: Tuple):
pipeline = TosaPipelineINT[input_t1](
Glu(),
(*test_data,),
aten_op=[],
exir_op=exir_op,
)
pipeline.run()


@common.parametrize(
"test_data",
test_data_suite,
)
@common.XfailIfNoCorstone300
def test_glu_u55_INT(test_data: Tuple):
pipeline = EthosU55PipelineINT[input_t1](
Glu(),
(*test_data,),
aten_ops=[],
exir_ops=exir_op,
)
pipeline.run()


@common.parametrize(
"test_data",
test_data_suite,
)
@common.XfailIfNoCorstone320
def test_glu_u85_INT(test_data: Tuple):
pipeline = EthosU85PipelineINT[input_t1](
Glu(),
(*test_data,),
aten_ops=[],
exir_ops=exir_op,
)
pipeline.run()


@common.parametrize(
"test_data",
test_data_suite,
)
@common.SkipIfNoModelConverter
def test_glu_vgf_FP(test_data: input_t1):
pipeline = VgfPipeline[input_t1](
Glu(),
(*test_data,),
[],
[],
tosa_version="TOSA-1.0+FP",
)
pipeline.run()


@common.parametrize(
"test_data",
test_data_suite,
)
@common.SkipIfNoModelConverter
def test_glu_vgf_INT(test_data: input_t1):
pipeline = VgfPipeline[input_t1](
Glu(),
(*test_data,),
[],
[],
tosa_version="TOSA-1.0+INT",
)
pipeline.run()
Loading