diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index a077bfd0b2b..34ac00ad1f6 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -35,6 +35,7 @@ from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa +from .decompose_expm1_pass import DecomposeExpm1Pass # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa from .decompose_glu_pass import DecomposeGluPass # noqa from .decompose_grouped_conv import DecomposeGroupedConv # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 5de0dd89aac..a17fbf69303 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -40,6 +40,7 @@ DecomposeCosineSimilarityPass, DecomposeDivPass, DecomposeEmbeddingPass, + DecomposeExpm1Pass, DecomposeGeluPass, DecomposeGluPass, DecomposeGroupedConv, @@ -164,6 +165,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: return self._transform(exported_program.graph_module) def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + self.add_pass(DecomposeExpm1Pass()) self.add_pass(DecomposeMaskedFill()) self.add_pass(DecomposeRoundPass()) self.add_pass(DecomposeAcoshPass()) diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py new file mode 100644 index 00000000000..5b1b90495b5 --- /dev/null +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -0,0 +1,135 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops + + +edge_expm1_ops = (exir_ops.edge.aten.expm1.default,) # MI case + + +def _get_expm1_decomposition(op) -> tuple: + """ + Returns the decomposition of the given aten.expm1 operation into + its equivalent TOSA-supported operations. + + This handles only the edge dialect ops listed in edge_expm1_ops. The decomposition strategy + is: + expm1(x) → where(and(ge(x, -0.35), le(x, 0.35)), {taylor_series_expansion}, (exp(x)-1)) + + where {taylor_series_expansion} = x + (x^2/2) + (x^3/6) + (x^4/24) + + Returns: + A tuple (op_pow, op_div, op_add, op_exp, op_sub, op_ge, op_where, op_le, op_and) + corresponding to the appropriate operator overloads for the input op. + + Raises: + RuntimeError: If the provided operator is not a supported expm1 variant. + """ + if op in edge_expm1_ops: + return ( + exir_ops.edge.aten.pow.Tensor_Scalar, + exir_ops.edge.aten.div.Scalar, + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.sub.Scalar, + exir_ops.edge.aten.ge.Scalar, + exir_ops.edge.aten.where.self, + exir_ops.edge.aten.le.Scalar, + exir_ops.edge.aten.logical_and.default, + ) + + raise RuntimeError(f"Can't get expm1 decomposition for op {op}") + + +class DecomposeExpm1Pass(ArmPass): + """ + A transformation pass that decomposes unsupported 'aten.expm1' operations + into a combination of supported TOSA-equivalent operations. 
+ + Since TOSA does not provide a native expm1 operator, this pass rewrites: + expm1(x) → where(and(ge(x, -0.35), le(x, 0.35)), {taylor_series_expansion}, (exp(x)-1)) + where {taylor_series_expansion} = x + (x^2/2) + (x^3/6) + (x^4/24) + + Supported input ops: + - exir_ops.edge.aten.expm1.default(x) + + These are replaced with: + - exir_ops.edge.aten.pow.Tensor_Scalar, + - exir_ops.edge.aten.div.Scalar, + - exir_ops.edge.aten.add.Tensor, + - exir_ops.edge.aten.exp.default, + - exir_ops.edge.aten.sub.Scalar, + - exir_ops.edge.aten.ge.Scalar, + - exir_ops.edge.aten.where.self, + - exir_ops.edge.aten.le.Scalar, + - exir_ops.edge.aten.logical_and.default + """ + + def call_operator(self, op, args, kwargs, meta): + if op not in edge_expm1_ops: + return super().call_operator(op, args, kwargs, meta, updated=False) + + ( + op_pow, + op_div, + op_add, + op_exp, + op_sub, + op_ge, + op_where, + op_le, + op_and, + ) = _get_expm1_decomposition(op) + + input = args[0] + + cutlo = -0.35 + cuthi = 0.35 + + taylor_term_2_numerator = super().call_operator( + op_pow, (input, 2), {}, meta, updated=False + ) + taylor_term_3_numerator = super().call_operator( + op_pow, (input, 3), {}, meta, updated=False + ) + taylor_term_4_numerator = super().call_operator( + op_pow, (input, 4), {}, meta, updated=False + ) + + taylor_term_2 = super().call_operator( + op_div, (taylor_term_2_numerator, 2), {}, meta, updated=False + ) + taylor_term_3 = super().call_operator( + op_div, (taylor_term_3_numerator, 6), {}, meta, updated=False + ) + taylor_term_4 = super().call_operator( + op_div, (taylor_term_4_numerator, 24), {}, meta, updated=False + ) + + add_terms_1_2 = super().call_operator( + op_add, (input, taylor_term_2), {}, meta, updated=False + ) + add_term_3 = super().call_operator( + op_add, (add_terms_1_2, taylor_term_3), {}, meta, updated=False + ) + taylor_expansion = super().call_operator( + op_add, (add_term_3, taylor_term_4), {}, meta, updated=False + ) + + decomp_exp = 
super().call_operator(op_exp, (input,), {}, meta, updated=False) + decomp_sub = super().call_operator( + op_sub, (decomp_exp, 1.0), {}, meta, updated=False + ) + + ge = super().call_operator(op_ge, (input, cutlo), {}, meta, updated=False) + le = super().call_operator(op_le, (input, cuthi), {}, meta, updated=False) + + cond_and = super().call_operator(op_and, (ge, le), {}, meta, updated=False) + where = super().call_operator( + op_where, (cond_and, taylor_expansion, decomp_sub), {}, meta, updated=True + ) + + return where diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 4b4d8078aa5..bead5e993f5 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -43,6 +43,7 @@ class TableOps: exir_ops.edge.aten.ceil.default: torch.ceil, exir_ops.edge.aten.erf.default: torch.erf, exir_ops.edge.aten.exp.default: torch.exp, + exir_ops.edge.aten.expm1.default: torch.expm1, exir_ops.edge.aten.floor.default: torch.floor, exir_ops.edge.aten.log.default: torch.log, exir_ops.edge.aten.reciprocal.default: torch.reciprocal, diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 9b472cd677b..c8c147d6046 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -179,6 +179,7 @@ def is_node_supported( exir_ops.edge.aten.eq.Scalar, exir_ops.edge.aten.erf.default, exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.expm1.default, exir_ops.edge.aten.log.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 5504a47f41b..a6f5671a881 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -265,6 +265,7 @@ def 
_match_pattern( torch.ops.aten.ceil.default, torch.ops.aten.erf.default, torch.ops.aten.exp.default, + torch.ops.aten.expm1.default, torch.ops.aten.floor.default, torch.ops.aten.log.default, torch.ops.aten.reciprocal.default, diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index a6d2ca9f2eb..9ceb5d73d23 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -8,6 +8,7 @@ CUSTOM_EDGE_OPS = [ "linspace.default", "eye.default", + "expm1.default", "vector_norm.default", "hardsigmoid.default", "hardswish.default", diff --git a/backends/arm/test/ops/test_expm1.py b/backends/arm/test/ops/test_expm1.py new file mode 100644 index 00000000000..dad95b24f7b --- /dev/null +++ b/backends/arm/test/ops/test_expm1.py @@ -0,0 +1,113 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +aten_op = "torch.ops.aten.expm1.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_expm1_default" + +input_t1 = Tuple[torch.Tensor] + +test_data_suite = { + "zeroes": torch.zeros(1, 10, 10, 10), + "ones": torch.ones(10, 2, 3), + "rand": torch.rand(10, 10) - 0.5, + "near_zero": torch.randn(100) * 0.01, + "taylor_small": torch.empty(5).uniform_( + -0.35, 0.35 + ), # values in [-0.35, 0.35], the interval where DecomposeExpm1Pass selects the Taylor series branch + "randn_large_pos": torch.randn(10) + 10, + "randn_large_neg": torch.randn(10) - 10, + "ramp": torch.arange(-16, 16, 0.2), +} + + +class Expm1(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return torch.expm1(x) + + +@common.parametrize("test_data", test_data_suite) +def test_expm1_tosa_FP(test_data: Tuple): + pipeline = TosaPipelineFP[input_t1]( + Expm1(), + (test_data,), + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_expm1_tosa_INT(test_data: Tuple): + pipeline = TosaPipelineINT[input_t1]( + Expm1(), + (test_data,), + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_suite) +def test_expm1_u55_INT(test_data: Tuple): + pipeline = EthosU55PipelineINT[input_t1]( + Expm1(), + (test_data,), + aten_ops=aten_op, + exir_ops=exir_op, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", test_data_suite) +def test_expm1_u85_INT(test_data: Tuple): + pipeline = EthosU85PipelineINT[input_t1]( + Expm1(), + (test_data,), + aten_ops=aten_op, + exir_ops=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_expm1_vgf_FP(test_data: 
Tuple): + pipeline = VgfPipeline[input_t1]( + Expm1(), + (test_data,), + aten_op, + exir_op, + tosa_version="TOSA-1.0+FP", + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_expm1_vgf_INT(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + Expm1(), + (test_data,), + aten_op, + exir_op, + tosa_version="TOSA-1.0+INT", + ) + pipeline.run()