diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index d7d07553a1e..a077bfd0b2b 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -36,6 +36,7 @@ from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa +from .decompose_glu_pass import DecomposeGluPass # noqa from .decompose_grouped_conv import DecomposeGroupedConv # noqa from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index d7434a534fe..5de0dd89aac 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -41,6 +41,7 @@ DecomposeDivPass, DecomposeEmbeddingPass, DecomposeGeluPass, + DecomposeGluPass, DecomposeGroupedConv, DecomposeGroupNormPass, DecomposeLayerNormPass, @@ -184,6 +185,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ConvertSplitToSlicePass()) self.add_pass(FuseBatchnorm2DPass(exported_program)) self.add_pass(ConvertMmToBmmPass()) + self.add_pass(DecomposeGluPass()) self.add_pass(DecomposeLinearPass()) self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeGroupNormPass()) @@ -264,6 +266,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec)) self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeCosineSimilarityPass()) + self.add_pass(DecomposeGluPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeLinearVectorNormPass()) diff --git a/backends/arm/_passes/decompose_glu_pass.py b/backends/arm/_passes/decompose_glu_pass.py new file mode 100644 index 00000000000..183dc89cf61 --- /dev/null +++ b/backends/arm/_passes/decompose_glu_pass.py @@ -0,0 +1,75 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops + + +# For FP case +edge_glu = exir_ops.edge.aten.glu.default + +# For INT case +aten_glu = torch.ops.aten.glu.default + + +def get_ops(op): + """Returns the appropriate operator functions based on the input operator.""" + if op == edge_glu: + return ( + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.slice_copy.Tensor, + ) + elif op == aten_glu: + return ( + torch.ops.aten.mul.Tensor, + torch.ops.aten.sigmoid.default, + torch.ops.aten.slice_copy.Tensor, + ) + else: + raise ValueError(f"Unsupported operator: {op}") + + +class DecomposeGluPass(ArmPass): + """Decomposes the GLU operator into hadamard product and sigmoid.""" + + def call_operator(self, op, args, kwargs, meta): + if op not in [edge_glu, aten_glu]: + return super().call_operator(op, args, kwargs, meta) + + hadamard_prod, sigmoid, slice_op = get_ops(op) + X = args[0] + + dim = args[1] if len(args) > 1 else kwargs.get("dim", -1) + + if "val" not in X.node.meta: + raise Exception("Could not get dimension metadata in input.") + + if dim < 0: + dim += X.node.meta["val"].dim() + + n = X.node.meta["val"].size(dim) + + if n % 2: + raise RuntimeError( + f"glu expects an even split along dim={dim}, got size {n}" + ) + + middle = n // 2 + + T1 = super().call_operator( + slice_op, (X, dim, 0, middle), {}, meta, updated=True + ) + + T2 = super().call_operator( + slice_op, (X, dim, middle, n), {}, meta, updated=True + ) + + T2_sigmoid = super().call_operator(sigmoid, (T2,), {}, meta, updated=True) + + return super().call_operator( + hadamard_prod, (T1, T2_sigmoid), {}, meta, updated=True + ) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index befd451493b..9b472cd677b 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -258,6 +258,7 @@ def is_node_supported( exir_ops.edge.aten.masked_fill.Scalar, exir_ops.edge.aten.asinh.default, exir_ops.edge.aten.cosh.default, + exir_ops.edge.aten.glu.default, ] return supported @@ -299,6 +300,7 @@ def is_node_supported( exir_ops.edge.aten.leaky_relu.default: None, exir_ops.edge.aten.round.default: None, exir_ops.edge.aten.addmm.default: None, + exir_ops.edge.aten.glu.default: None, } if node.target in needs_decomp_dict: diff --git a/backends/arm/test/ops/test_glu.py b/backends/arm/test/ops/test_glu.py new file mode 100644 index 00000000000..c19fb892c92 --- /dev/null +++ b/backends/arm/test/ops/test_glu.py @@ -0,0 +1,130 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +import torch.nn.functional as F +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +aten_op = "torch.ops.aten.glu.default" +exir_op = "executorch_exir_dialects_edge__ops_aten__glu_default" + + +input_t1 = Tuple[torch.Tensor] + +test_data_suite = { + "zeros": [torch.zeros(10, 10, 2), -1], + "ones": [torch.ones(10, 10, 2), -1], + "rand": [torch.rand(10, 10, 2) - 0.5, -1], + "randn_pos": [torch.randn(10, 2) + 10, -1], + "randn_neg": [torch.randn(10, 2) - 10, -1], + "ramp": [torch.linspace(-16, 15.8, 160).reshape(-1, 2), -1], + "zeros_custom_dim": [torch.zeros(7, 10, 5), 1], + "rand_custom_dim": [torch.rand(10, 3, 3) - 0.5, 0], +} + + +class Glu(torch.nn.Module): + + def forward(self, a: torch.Tensor, dim: int) -> torch.Tensor: + return F.glu(a, dim=dim) + + +@common.parametrize( + "test_data", + test_data_suite, +) +def test_glu_tosa_FP(test_data: Tuple): + pipeline = TosaPipelineFP[input_t1]( + Glu(), + (*test_data,), + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite, +) +def test_glu_tosa_INT(test_data: Tuple): + pipeline = TosaPipelineINT[input_t1]( + Glu(), + (*test_data,), + aten_op=[], + exir_op=exir_op, + ) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite, +) +@common.XfailIfNoCorstone300 +def test_glu_u55_INT(test_data: Tuple): + pipeline = EthosU55PipelineINT[input_t1]( + Glu(), + (*test_data,), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite, +) +@common.XfailIfNoCorstone320 +def test_glu_u85_INT(test_data: Tuple): + pipeline = EthosU85PipelineINT[input_t1]( + Glu(), + (*test_data,), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite, +) +@common.SkipIfNoModelConverter +def test_glu_vgf_FP(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + Glu(), + (*test_data,), + [], + [], + tosa_version="TOSA-1.0+FP", + ) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_suite, +) +@common.SkipIfNoModelConverter +def test_glu_vgf_INT(test_data: input_t1): + pipeline = VgfPipeline[input_t1]( + Glu(), + (*test_data,), + [], + [], + tosa_version="TOSA-1.0+INT", + ) + pipeline.run()