diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index deacfb7ec6f..87ac8910c17 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -58,6 +58,7 @@ from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_ne_pass import DecomposeNotEqualPass # noqa +from .decompose_remainder_pass import DecomposeRemainderPass # noqa from .decompose_round_pass import DecomposeRoundPass # noqa from .decompose_select import DecomposeSelectPass # noqa from .decompose_sign_pass import DecomposeSignPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index d6e63100603..3a9143488b6 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -62,6 +62,7 @@ DecomposeMaxPool2DPass, DecomposeMeanDimPass, DecomposeNotEqualPass, + DecomposeRemainderPass, DecomposeRoundPass, DecomposeSelectPass, DecomposeSignPass, @@ -240,8 +241,9 @@ def _tosa_FP_pipeline( self.add_pass(CastBoolToInt8Pass()) self.add_pass(DecomposeSinhPass()) self.add_pass(DecomposeSignPass()) - self.add_pass(DecomposeDivTensorModePass()) self.add_pass(ReplaceScalarWithTensorByProfilePass()) + self.add_pass(DecomposeRemainderPass()) + self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) @@ -331,9 +333,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(CastBoolToInt8Pass()) self.add_pass(DecomposeSignPass()) self.add_pass(DecomposeAddmmPass()) + self.add_pass(ReplaceScalarWithTensorByProfilePass()) + self.add_pass(DecomposeRemainderPass()) self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeAddSubAlphaPass()) - self.add_pass(ReplaceScalarWithTensorByProfilePass()) self.add_pass(ScalarsToAttributePass()) 
self.add_pass(DecomposeGroupNormPass()) self.add_pass(DecomposeLayerNormPass()) diff --git a/backends/arm/_passes/decompose_remainder_pass.py b/backends/arm/_passes/decompose_remainder_pass.py new file mode 100644 index 00000000000..ac37eae86df --- /dev/null +++ b/backends/arm/_passes/decompose_remainder_pass.py @@ -0,0 +1,66 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.decompose_div_tensor_mode import ( + DecomposeDivTensorModePass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass +from torch._ops import OpOverload + +Op = OpOverload | EdgeOpOverload + + +def _get_remainder_decomposition_ops(op: Op) -> tuple[Op, Op, Op]: + """ + Returns the (div_mode_op, mul_op, sub_op) needed to lower the provided + remainder operator. The concrete ops depend on whether the remainder op is + the aten or edge variant. + """ + if op == exir_ops.edge.aten.remainder.Tensor: + return ( + exir_ops.edge.aten.div.Tensor_mode, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.sub.Tensor, + ) + if op == torch.ops.aten.remainder.Tensor: + return ( + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.mul.Tensor, + torch.ops.aten.sub.Tensor, + ) + raise RuntimeError(f"Can't get remainder decomposition ops for op {op}") + + +class DecomposeRemainderPass(ArmPass): + """ + Decompose the remainder operation into primitive arithmetic: + remainder(x, y) -> x - floor_div(x, y) * y + where floor_div(x, y) == div(x, y, rounding_mode=\"floor\"). 
+ """ + + _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass} + + def call_operator(self, op, args, kwargs, meta, updated=False): + supported_ops = ( + exir_ops.edge.aten.remainder.Tensor, + torch.ops.aten.remainder.Tensor, + ) + if op not in supported_ops: + return super().call_operator(op, args, kwargs, meta, updated) + + div_op, mul_op, sub_op = _get_remainder_decomposition_ops(op) + x, y = args[0], args[1] + + floor_div = super().call_operator( + div_op, (x, y), {"rounding_mode": "floor"}, meta, updated=True + ) + product = super().call_operator(mul_op, (floor_div, y), {}, meta, updated=True) + return super().call_operator(sub_op, (x, product), {}, meta, updated=True) diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index 579ac825e9e..f5ab5f633ba 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -40,6 +40,7 @@ exir_ops.edge.aten.bitwise_and.Scalar: exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_or.Scalar: exir_ops.edge.aten.bitwise_or.Tensor, exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.remainder.Scalar: exir_ops.edge.aten.remainder.Tensor, torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor, torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor, torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor, @@ -55,6 +56,7 @@ torch.ops.aten.bitwise_and.Scalar: torch.ops.aten.bitwise_and.Tensor, torch.ops.aten.bitwise_or.Scalar: torch.ops.aten.bitwise_or.Tensor, torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor, + torch.ops.aten.remainder.Scalar: torch.ops.aten.remainder.Tensor, } _fp_profile_ops: Dict[ diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index b91ed4fb130..61f51165a33 100644 --- 
a/backends/arm/operator_support/tosa_profile_supported_op_lists.py
+++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py
@@ -77,6 +77,7 @@
     exir_ops.edge.aten.repeat.default,
     exir_ops.edge.aten.reciprocal.default,
     exir_ops.edge.aten.relu.default,
+    exir_ops.edge.aten.remainder.Tensor,
     exir_ops.edge.aten.rsqrt.default,
     exir_ops.edge.aten.select_copy.int,
     exir_ops.edge.aten.sub.Tensor,
@@ -185,6 +186,8 @@
     exir_ops.edge.aten.repeat.default,
     exir_ops.edge.aten.reciprocal.default,
     exir_ops.edge.aten.relu.default,
+    exir_ops.edge.aten.remainder.Scalar,
+    exir_ops.edge.aten.remainder.Tensor,
     exir_ops.edge.aten.leaky_relu.default,
     exir_ops.edge.aten.sqrt.default,
     exir_ops.edge.aten.rsqrt.default,
diff --git a/backends/arm/test/ops/test_remainder.py b/backends/arm/test/ops/test_remainder.py
new file mode 100644
index 00000000000..2cba9532cde
--- /dev/null
+++ b/backends/arm/test/ops/test_remainder.py
@@ -0,0 +1,143 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineINT,
+    EthosU85PipelineINT,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+
+def _nonzero_float_tensor(*shape: int) -> torch.Tensor:
+    """Return a float32 tensor of ``shape`` with values in [0.1, 5.1)."""
+    return torch.rand(*shape, dtype=torch.float32) * 5 + 0.1
+
+
+class Remainder(torch.nn.Module):
+    """Module under test: forwards both inputs to ``torch.remainder``."""
+
+    input_t = Tuple[torch.Tensor | float, torch.Tensor | float]
+
+    # Lazily built (x, y) input pairs keyed by test-case name. Tensor divisors
+    # come from _nonzero_float_tensor, so they never contain zero.
+    test_cases = {
+        "rank2_tensors": lambda: (
+            torch.randn(2, 3) * 7,
+            _nonzero_float_tensor(2, 3),
+        ),
+        "rank4_tensors": lambda: (
+            torch.randn(1, 4, 2, 3) * 7,
+            _nonzero_float_tensor(1, 4, 2, 3),
+        ),
+        "broadcast": lambda: (
+            torch.randn(4, 5, 1),
+            _nonzero_float_tensor(1, 5, 6),
+        ),
+        "scalar_rhs": lambda: (
+            torch.randn(1, 2, 3, 4),
+            0.25,
+        ),
+    }
+
+    def forward(self, x: torch.Tensor | float, y: torch.Tensor | float) -> torch.Tensor:
+        return torch.remainder(x, y)
+
+
+def _has_scalar_rhs(test_data: Remainder.input_t) -> bool:
+    """Return True when the second input (the divisor) is a Python float."""
+    return isinstance(test_data[1], float)
+
+
+def _get_aten_op(test_data: Remainder.input_t):
+    """Return the aten remainder overload name expected for ``test_data``."""
+    # Keyed on the right-hand operand only, like _get_exir_op: remainder.Scalar
+    # takes (Tensor, Scalar), so a float left-hand operand must not select it.
+    if _has_scalar_rhs(test_data):
+        return "torch.ops.aten.remainder.Scalar"
+    return "torch.ops.aten.remainder.Tensor"
+
+
+def _get_exir_op(test_data: Remainder.input_t):
+    """Return the edge remainder overload name expected for ``test_data``."""
+    if _has_scalar_rhs(test_data):
+        return "executorch_exir_dialects_edge__ops_aten_remainder_Scalar"
+    return "executorch_exir_dialects_edge__ops_aten_remainder_Tensor"
+
+
+@common.parametrize("test_data", Remainder.test_cases)
+def test_remainder_tosa_FP(test_data):
+    data = test_data()
+    pipeline = TosaPipelineFP[Remainder.input_t](
+        Remainder(),
+        data,
+        _get_aten_op(data),
+        _get_exir_op(data),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Remainder.test_cases)
+def test_remainder_tosa_INT(test_data):
+    pipeline = TosaPipelineINT[Remainder.input_t](
+        Remainder(),
+        test_data(),
+        [],
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Remainder.test_cases)
+@common.XfailIfNoCorstone300
+def test_remainder_u55_INT(test_data):
+    pipeline = EthosU55PipelineINT[Remainder.input_t](
+        Remainder(),
+        test_data(),
+        [],
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Remainder.test_cases)
+@common.XfailIfNoCorstone320
+def test_remainder_u85_INT(test_data):
+    pipeline = EthosU85PipelineINT[Remainder.input_t](
+        Remainder(),
+        test_data(),
+        [],
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Remainder.test_cases)
+@common.SkipIfNoModelConverter
+def test_remainder_vgf_FP(test_data):
+    data = test_data()
+    pipeline = VgfPipeline[Remainder.input_t](
+        Remainder(),
+        data,
+        _get_aten_op(data),
+        _get_exir_op(data),
+        tosa_version="TOSA-1.0+FP",
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Remainder.test_cases)
+@common.SkipIfNoModelConverter
+def test_remainder_vgf_INT(test_data):
+    pipeline = VgfPipeline[Remainder.input_t](
+        Remainder(),
+        test_data(),
+        [],
+        tosa_version="TOSA-1.0+INT",
+    )
+    pipeline.run()