diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index f9cf838a526..2a75606cb70 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -22,6 +22,7 @@
 from .convert_split_to_slice import ConvertSplitToSlicePass  # noqa
 from .convert_squeezes_to_view import ConvertSqueezesToViewPass  # noqa
 from .convert_to_clamp import ConvertToClampPass  # noqa
+from .decompose_atan_pass import DecomposeAtanPass  # noqa
 from .decompose_avg_pool2d import DecomposeAvgPool2d  # noqa
 from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass  # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index f0a86b1ce84..596decd65bb 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -25,6 +25,7 @@
     ConvertSplitToSlicePass,
     ConvertSqueezesToViewPass,
     ConvertToClampPass,
+    DecomposeAtanPass,
     DecomposeAvgPool2d,
     DecomposeBatchNormNoStatsPass,
     DecomposeCosineSimilarityPass,
@@ -151,6 +152,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeRoundPass())
         self.add_pass(DecomposeSqrtPass())
+        self.add_pass(DecomposeAtanPass())
         self.add_pass(ConvertIntPowToMuls())
         self.add_pass(CastBoolToInt8Pass())
         self.add_pass(DecomposeSinhPass())
diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py
new file mode 100644
index 00000000000..57b9dde5216
--- /dev/null
+++ b/backends/arm/_passes/decompose_atan_pass.py
@@ -0,0 +1,119 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from math import pi
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+edge_atan = exir_ops.edge.aten.atan.default  # MI case
+
+
+def _get_atan_ops(op):
+    """Return the primitive ops required to decompose atan."""
+    if op is not edge_atan:
+        raise RuntimeError(f"Can't decompose atan for op {op}")
+
+    return (
+        exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.mul.Scalar,
+        exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.add.Scalar,
+        exir_ops.edge.aten.sub.Tensor,
+        exir_ops.edge.aten.abs.default,
+        exir_ops.edge.aten.gt.Scalar,
+        exir_ops.edge.aten.reciprocal.default,
+        exir_ops.edge.aten.where.self,
+        exir_ops.edge.aten.neg.default,
+    )
+
+
+class DecomposeAtanPass(ArmPass):
+    """Decomposes the atan operator into a rational (Padé) approximation."""
+
+    def _rational_approximation(self, z, ops, meta):
+        """Creates a (2,1) Padé approximation for atan(x) on [-1, 1]."""
+
+        op_mul, op_mul_scalar, op_add, op_add_scalar, _, _, _, op_recip, _, _ = ops
+
+        # Coefficients calculated using minimax on the interval [-1, 1].
+        a1 = 0.3529666667
+        a2 = -0.0287666667
+        b1 = 0.6863
+
+        z2 = super().call_operator(op_mul, (z, z), {}, meta, updated=True)
+        z4 = super().call_operator(op_mul, (z2, z2), {}, meta, updated=True)
+
+        num1 = super().call_operator(op_mul_scalar, (z2, a1), {}, meta, updated=True)
+        num2 = super().call_operator(op_mul_scalar, (z4, a2), {}, meta, updated=True)
+        num = super().call_operator(op_add_scalar, (num1, 1.0), {}, meta, updated=True)
+        num = super().call_operator(op_add, (num, num2), {}, meta, updated=True)
+
+        den1 = super().call_operator(op_mul_scalar, (z2, b1), {}, meta, updated=True)
+        den = super().call_operator(op_add_scalar, (den1, 1.0), {}, meta, updated=True)
+
+        inv_den = super().call_operator(op_recip, (den,), {}, meta, updated=True)
+
+        prod = super().call_operator(op_mul, (num, inv_den), {}, meta, updated=True)
+        return super().call_operator(op_mul, (z, prod), {}, meta, updated=True)
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op is not edge_atan:
+            return super().call_operator(op, args, kwargs, meta, updated=False)
+
+        logging.info(
+            f"Approximating atan. This may introduce small numerical errors. For details, see {__file__}."
+        )
+
+        ops = _get_atan_ops(op)
+        (
+            _,
+            op_mul_scalar,
+            _,
+            op_add_scalar,
+            op_sub,
+            op_abs,
+            op_gt,
+            op_recip,
+            op_where,
+            op_neg,
+        ) = ops
+
+        x = args[0]
+
+        # |x| > 1 is reduced to [0, 1] using atan(x) = pi/2 - atan(1/x) and atan(-x) = -atan(x).
+
+        abs_x = super().call_operator(op_abs, (x,), {}, meta, updated=True)
+        mask_hi = super().call_operator(op_gt, (abs_x, 1.0), {}, meta, updated=True)
+
+        inv_x = super().call_operator(op_recip, (abs_x,), {}, meta, updated=True)
+        z = super().call_operator(
+            op_where, (mask_hi, inv_x, abs_x), {}, meta, updated=True
+        )
+
+        atan_z = self._rational_approximation(z, ops, meta)
+
+        zero_tensor = super().call_operator(
+            op_mul_scalar, (x, 0.0), {}, meta, updated=True
+        )
+        half_pi_tensor = super().call_operator(
+            op_add_scalar, (zero_tensor, pi / 2), {}, meta, updated=True
+        )
+
+        diff = super().call_operator(
+            op_sub, (half_pi_tensor, atan_z), {}, meta, updated=True
+        )
+        atan_abs = super().call_operator(
+            op_where, (mask_hi, diff, atan_z), {}, meta, updated=True
+        )
+
+        mask_pos = super().call_operator(op_gt, (x, 0.0), {}, meta, updated=True)
+        neg_val = super().call_operator(op_neg, (atan_abs,), {}, meta, updated=True)
+
+        return super().call_operator(
+            op_where, (mask_pos, atan_abs, neg_val), {}, meta, updated=True
+        )
diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py
index c579fcb0301..b31b6c7106d 100644
--- a/backends/arm/_passes/insert_table_ops.py
+++ b/backends/arm/_passes/insert_table_ops.py
@@ -51,6 +51,7 @@ class TableOps:
         exir_ops.edge.aten.cos.default: torch.cos,
         exir_ops.edge.aten.sin.default: torch.sin,
         exir_ops.edge.aten.tanh.default: torch.tanh,
+        exir_ops.edge.aten.atan.default: torch.atan,
         exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
         exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
         exir_ops.edge.aten.sinh.default: torch.sinh,
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 639df536109..cdb27b7c31e 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -244,6 +244,7 @@ def is_node_supported(
             exir_ops.edge.aten.gelu.default,
             exir_ops.edge.aten.alias_copy.default,
             exir_ops.edge.aten.sinh.default,
+            exir_ops.edge.aten.atan.default,
         ]
 
         return supported
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index c6415c63777..2c61aea60c3 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -214,6 +214,7 @@ def _match_pattern(
     torch.ops.aten.pow.Tensor_Scalar,
     torch.ops.aten.gelu.default,
     torch.ops.aten.sinh.default,
+    torch.ops.aten.atan.default,
 ]
 
 _one_to_one_shared_input_qspec = [
diff --git a/backends/arm/test/ops/test_atan.py b/backends/arm/test/ops/test_atan.py
new file mode 100644
index 00000000000..3d6f8cd8fa8
--- /dev/null
+++ b/backends/arm/test/ops/test_atan.py
@@ -0,0 +1,84 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+aten_op = "torch.ops.aten.atan.default"
+exir_op = "executorch_exir_dialects_edge__ops_aten__atan_default"
+
+input_t1 = Tuple[torch.Tensor]
+
+test_data_suite = {
+    "zeros": torch.zeros(1, 10, 10, 10),
+    "zeros_alt_shape": torch.zeros(1, 10, 3, 5),
+    "ones": torch.ones(10, 10, 10),
+    "rand": torch.rand(10, 10) - 0.5,
+    "rand_alt_shape": torch.rand(1, 10, 3, 5) - 0.5,
+    "randn_pos": torch.randn(10) + 10,
+    "randn_neg": torch.randn(10) - 10,
+    "ramp": torch.arange(-16, 16, 0.2),
+}
+
+
+class Atan(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor):
+        return torch.atan(x)
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_atan_tosa_MI(test_data: Tuple):
+    pipeline = TosaPipelineMI[input_t1](
+        Atan(),
+        (test_data,),
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_atan_tosa_BI(test_data: Tuple):
+    pipeline = TosaPipelineBI[input_t1](
+        Atan(),
+        (test_data,),
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone300
+@common.parametrize("test_data", test_data_suite)
+def test_atan_u55_BI(test_data: Tuple):
+    pipeline = EthosU55PipelineBI[input_t1](
+        Atan(),
+        (test_data,),
+        aten_ops=aten_op,
+        exir_ops=exir_op,
+    )
+    pipeline.run()
+
+
+@common.XfailIfNoCorstone320
+@common.parametrize("test_data", test_data_suite)
+def test_atan_u85_BI(test_data: Tuple):
+    pipeline = EthosU85PipelineBI[input_t1](
+        Atan(),
+        (test_data,),
+        aten_ops=aten_op,
+        exir_ops=exir_op,
+    )
+    pipeline.run()