diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 93bf20e69c1..4a86868f0b1 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -42,6 +42,7 @@ from .decompose_elu_pass import DecomposeEluPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa from .decompose_expm1_pass import DecomposeExpm1Pass # noqa +from .decompose_fmod_pass import DecomposeFmodPass # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa from .decompose_glu_pass import DecomposeGluPass # noqa from .decompose_grouped_conv import DecomposeGroupedConv # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b7c511bbe0b..2133a3bd803 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -51,6 +51,7 @@ DecomposeEluPass, DecomposeEmbeddingPass, DecomposeExpm1Pass, + DecomposeFmodPass, DecomposeGeluPass, DecomposeGluPass, DecomposeGroupedConv, @@ -240,6 +241,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(DecomposeSignPass()) self.add_pass(DecomposeDivTensorModePass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) + self.add_pass(DecomposeFmodPass()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) @@ -338,6 +340,12 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeCosineSimilarityPass()) self.add_pass(DecomposeGluPass()) + + if not self.tosa_spec.is_U55_subset: + # Uses where which is not supported on Ethos-U55 + self.add_pass(DecomposeMaskedFill()) + self.add_pass(DecomposeFmodPass()) + self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeLinearVectorNormPass()) @@ -355,10 +363,6 @@ def transform_for_annotation_pipeline(self, 
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops

# The fmod overloads handled by this pass, one tuple per dialect.
exir_op = (exir_ops.edge.aten.fmod.Tensor,)
aten_op = (torch.ops.aten.fmod.Tensor,)


def _get_decomposition(op) -> tuple:
    """Return the helper operators (sub, div, mul, ceil, floor, where, lt,
    full_like) from the same dialect as ``op``.

    Raises:
        RuntimeError: if ``op`` is not an fmod overload handled by this pass.
    """
    if op in exir_op:
        return (
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.div.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.ceil.default,
            exir_ops.edge.aten.floor.default,
            exir_ops.edge.aten.where.self,
            exir_ops.edge.aten.lt.Tensor,
            exir_ops.edge.aten.full_like.default,
        )
    if op in aten_op:
        return (
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.div.Tensor,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.ceil.default,
            torch.ops.aten.floor.default,
            torch.ops.aten.where.self,
            torch.ops.aten.lt.Tensor,
            torch.ops.aten.full_like.default,
        )

    # Specific exception type instead of bare Exception; RuntimeError is a
    # subclass of Exception, so existing `except Exception` handlers still work.
    raise RuntimeError(f"Unable to get decomposition for {op}")


class DecomposeFmodPass(ArmPass):
    """
    Decomposes the fmod operator according to the following formula:
    fmod(x, y) = x - x.div(y, rounding_mode="trunc") * y

    Truncation of the quotient is emulated with floor (quotient >= 0) and
    ceil (quotient < 0) selected via `where`, because the decomposition
    targets dialects without a direct trunc operator.
    """

    def call_operator(self, op, args, kwargs, meta, updated=False):
        # Pass every operator this pass does not handle straight through.
        if op not in (exir_op + aten_op):
            return super().call_operator(op, args, kwargs, meta, updated)

        sub_op, div_op, mul_op, ceil_op, floor_op, where_op, lt_op, full_like_op = (
            _get_decomposition(op)
        )

        x, y = args

        # Raw quotient x / y; rounding towards zero is applied below.
        div = super().call_operator(div_op, (x, y), {}, meta, True)

        floor_round = super().call_operator(floor_op, (div,), {}, meta, True)
        ceil_round = super().call_operator(ceil_op, (div,), {}, meta, True)

        # Create a mask to determine which quotient values are negative
        # and use it to select the appropriate rounding method:
        # ceil for negative values, floor otherwise. Together this rounds
        # towards zero, i.e. truncates.
        zeros = super().call_operator(full_like_op, (div, 0.0), {}, meta, True)
        mask = super().call_operator(lt_op, (div, zeros), {}, meta, True)

        rounded_values = super().call_operator(
            where_op, (mask, ceil_round, floor_round), {}, meta, True
        )

        # fmod(x, y) = x - trunc(x / y) * y
        mul = super().call_operator(mul_op, (rounded_values, y), {}, meta, True)

        out = super().call_operator(sub_op, (x, mul), {}, meta, True)

        return out
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU85PipelineINT,
    OpNotSupportedPipeline,
    TosaPipelineFP,
    TosaPipelineINT,
    VgfPipeline,
)

# Input signature type used to parameterize the test pipelines.
input_t = Tuple[torch.Tensor]

# Operator names the pipelines look for in the exported / lowered graphs.
aten_op_tensor = "torch.ops.aten.fmod.Tensor"
aten_op_scalar = "torch.ops.aten.fmod.Scalar"
exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_fmod_Tensor"
exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_fmod_Scalar"


class Fmod(torch.nn.Module):
    # Thin wrapper around torch.fmod; the divisor's type (Tensor vs float)
    # selects the fmod.Tensor or fmod.Scalar overload under test.
    def forward(self, x: torch.Tensor, y: torch.Tensor | float):
        return torch.fmod(x, y)


# Cases exercising the fmod.Scalar overload (float divisor).
test_data_scalar = {
    "fmod_tensor_scalar_pos": lambda: (
        Fmod(),
        torch.tensor([[10.0, 25.5], [-33.2, 4.4]]),
        3.0,
    ),
    "fmod_tensor_scalar_neg": lambda: (
        Fmod(),
        torch.tensor([[10.0, -25.5], [-33.2, 4.4]]),
        -5.0,
    ),
    "fmod_tensor_scalar_one": lambda: (Fmod(), torch.randn(2, 3, 4), 1.0),
    "fmod_tensor_scalar_small_float": lambda: (
        Fmod(),
        torch.tensor([0.123, -0.456, 0.789]),
        0.1,
    ),
    "fmod_tensor_scalar_large_float": lambda: (
        Fmod(),
        torch.tensor([1e8, -1e9, 3.5e6]),
        1e6,
    ),
    # Division by zero is expected to fail on INT pipelines (see `xfails`).
    "fmod_division_by_zero": lambda: (Fmod(), torch.randn(2, 3, 4), 0.0),
}


# Cases exercising the fmod.Tensor overload (tensor divisor), including
# broadcasting, mixed signs, and extreme magnitudes.
test_data_tensor = {
    "fmod_zeros": lambda: (
        Fmod(),
        torch.zeros(1, 10, 10, 10),
        torch.ones(1, 10, 10, 10),
    ),
    "fmod_ones": lambda: (Fmod(), torch.ones(1, 10, 10, 10), torch.ones(1, 10, 10, 10)),
    "fmod_rand": lambda: (Fmod(), torch.rand(10, 10) - 0.5, torch.rand(10, 10) + 0.5),
    "fmod_randn_pos": lambda: (
        Fmod(),
        torch.randn(1, 4, 4, 4) + 10,
        torch.randn(1, 4, 4, 4) + 10,
    ),
    "fmod_randn_neg": lambda: (
        Fmod(),
        torch.randn(1, 4, 4, 4) - 10,
        torch.randn(1, 4, 4, 4) + 10,
    ),
    "fmod_broadcast": lambda: (
        Fmod(),
        torch.tensor([[10.0, 20.0], [30.0, 40.0]]),
        torch.tensor([3.0, 7.0]),
    ),
    "fmod_negative_divisor": lambda: (
        Fmod(),
        torch.tensor([[10.0, -20.0], [-30.0, 40.0]]),
        torch.tensor([[-3.0, -5.0], [-7.0, -6.0]]),
    ),
    # Divisor contains a zero element; expected to fail on INT pipelines.
    "fmod_division_by_zero": lambda: (
        Fmod(),
        torch.tensor([1.0, 2.0, 3.0]),
        torch.tensor([0.0, 1.0, 2.0]),
    ),
    "fmod_mixed_signs": lambda: (
        Fmod(),
        torch.tensor([-10.0, 20.0, -30.0]),
        torch.tensor([3.0, -5.0, 7.0]),
    ),
    "fmod_scalar_tensor": lambda: (Fmod(), torch.tensor(10.0), torch.tensor(3.0)),
    "fmod_large_values": lambda: (
        Fmod(),
        torch.tensor([1e19, -1e21]),
        torch.tensor([3.0, 5.0]),
    ),
    "fmod_small_values": lambda: (
        Fmod(),
        torch.tensor([1e-10, -1e-12]),
        torch.tensor([1e-5, 2e-5]),
    ),
}

# Shared expected-failure set: the key matches an entry in both data dicts.
xfails = {"fmod_division_by_zero": "Invalid inputs not handled"}


@common.parametrize("test_data", test_data_scalar)
def test_fmod_scalar_tosa_FP(test_data: input_t):
    # TOSA floating-point pipeline, scalar divisor.
    module, data_x, data_y = test_data()
    pipeline = TosaPipelineFP[input_t](
        module, (data_x, data_y), aten_op=aten_op_scalar, exir_op=exir_op_scalar
    )
    pipeline.run()


@common.parametrize("test_data", test_data_tensor)
def test_fmod_tensor_tosa_FP(test_data: input_t):
    # TOSA floating-point pipeline, tensor divisor.
    module, data_x, data_y = test_data()
    pipeline = TosaPipelineFP[input_t](
        module, (data_x, data_y), aten_op=aten_op_tensor, exir_op=exir_op_tensor
    )
    pipeline.run()


@common.parametrize("test_data", test_data_scalar, xfails=xfails)
def test_fmod_scalar_tosa_INT(test_data: input_t):
    # TOSA integer (quantized) pipeline; no specific aten op asserted since
    # fmod is decomposed before lowering.
    module, data_x, data_y = test_data()
    pipeline = TosaPipelineINT[input_t](
        module,
        (data_x, data_y),
        aten_op=[],
    )
    pipeline.run()


@common.parametrize("test_data", test_data_tensor, xfails=xfails)
def test_fmod_tensor_tosa_INT(test_data: input_t):
    module, data_x, data_y = test_data()
    pipeline = TosaPipelineINT[input_t](
        module,
        (data_x, data_y),
        aten_op=[],
    )
    pipeline.run()


@common.parametrize("test_data", test_data_scalar)
@common.XfailIfNoCorstone300
def test_fmod_scalar_u55_INT(test_data):
    # fmod decomposes via `where`, which is not supported on Ethos-U55, so
    # the op must NOT be delegated there.
    module, data_x, data_y = test_data()
    pipeline = OpNotSupportedPipeline[input_t](
        module,
        (data_x, data_y),
        {
            exir_op_tensor: 1,
        },
        n_expected_delegates=0,
        quantize=True,
        u55_subset=True,
    )
    pipeline.run()


@common.parametrize("test_data", test_data_tensor)
@common.XfailIfNoCorstone300
def test_fmod_tensor_u55_INT(test_data):
    # Same as the scalar U55 case: expect zero delegates for fmod.
    module, data_x, data_y = test_data()
    pipeline = OpNotSupportedPipeline[input_t](
        module,
        (data_x, data_y),
        {
            exir_op_tensor: 1,
        },
        n_expected_delegates=0,
        quantize=True,
        u55_subset=True,
    )
    pipeline.run()


@common.parametrize("test_data", test_data_scalar, xfails=xfails)
@common.XfailIfNoCorstone320
def test_fmod_scalar_u85_INT(test_data: input_t):
    # Ethos-U85 supports the decomposed form; run the full INT pipeline.
    module, data_x, data_y = test_data()
    pipeline = EthosU85PipelineINT[input_t](module, (data_x, data_y), aten_ops=[])
    pipeline.run()


@common.parametrize("test_data", test_data_tensor, xfails=xfails)
@common.XfailIfNoCorstone320
def test_fmod_tensor_u85_INT(test_data: input_t):
    module, data_x, data_y = test_data()
    pipeline = EthosU85PipelineINT[input_t](module, (data_x, data_y), aten_ops=[])
    pipeline.run()


@common.parametrize("test_data", test_data_scalar)
@common.SkipIfNoModelConverter
def test_fmod_scalar_vgf_FP(test_data: Tuple):
    # VGF (Vulkan graph format) FP pipeline.
    module, data_x, data_y = test_data()
    pipeline = VgfPipeline[input_t](
        module,
        (data_x, data_y),
        [],
        [],
        tosa_version="TOSA-1.0+FP",
    )
    pipeline.run()


@common.parametrize("test_data", test_data_scalar, xfails=xfails)
@common.SkipIfNoModelConverter
def test_fmod_scalar_vgf_INT(test_data: Tuple):
    module, data_x, data_y = test_data()
    pipeline = VgfPipeline[input_t](
        module,
        (data_x, data_y),
        [],
        [],
        tosa_version="TOSA-1.0+INT",
    )
    pipeline.run()


@common.parametrize("test_data", test_data_tensor)
@common.SkipIfNoModelConverter
def test_fmod_tensor_vgf_FP(test_data: Tuple):
    module, data_x, data_y = test_data()
    pipeline = VgfPipeline[input_t](
        module,
        (data_x, data_y),
        [],
        [],
        tosa_version="TOSA-1.0+FP",
    )
    pipeline.run()


@common.parametrize("test_data", test_data_tensor, xfails=xfails)
@common.SkipIfNoModelConverter
def test_fmod_tensor_vgf_INT(test_data: Tuple):
    module, data_x, data_y = test_data()
    pipeline = VgfPipeline[input_t](
        module,
        (data_x, data_y),
        [],
        [],
        tosa_version="TOSA-1.0+INT",
    )
    pipeline.run()