Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
from .decompose_round_pass import DecomposeRoundPass # noqa
from .decompose_select import DecomposeSelectPass # noqa
from .decompose_sign_pass import DecomposeSignPass # noqa
Expand Down
7 changes: 5 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
DecomposeMaxPool2DPass,
DecomposeMeanDimPass,
DecomposeNotEqualPass,
DecomposeRemainderPass,
DecomposeRoundPass,
DecomposeSelectPass,
DecomposeSignPass,
Expand Down Expand Up @@ -240,8 +241,9 @@ def _tosa_FP_pipeline(
self.add_pass(CastBoolToInt8Pass())
self.add_pass(DecomposeSinhPass())
self.add_pass(DecomposeSignPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(DecomposeRemainderPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
Expand Down Expand Up @@ -331,9 +333,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(CastBoolToInt8Pass())
self.add_pass(DecomposeSignPass())
self.add_pass(DecomposeAddmmPass())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(DecomposeRemainderPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(DecomposeAddSubAlphaPass())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeGroupNormPass())
self.add_pass(DecomposeLayerNormPass())
Expand Down
66 changes: 66 additions & 0 deletions backends/arm/_passes/decompose_remainder_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.decompose_div_tensor_mode import (
DecomposeDivTensorModePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass
from torch._ops import OpOverload

# Operator handle type used in this module's signatures: either a torch
# OpOverload (aten variant) or an EdgeOpOverload (edge-dialect variant).
Op = OpOverload | EdgeOpOverload


def _get_remainder_decomposition_ops(op: Op) -> tuple[Op, Op, Op]:
    """
    Pick the operators used to express a remainder op with simpler arithmetic.

    Args:
        op: The remainder operator being lowered; either the edge-dialect
            ``remainder.Tensor`` or the aten ``remainder.Tensor``.

    Returns:
        A ``(div_mode_op, mul_op, sub_op)`` tuple taken from the same operator
        namespace (edge or aten) as ``op``.

    Raises:
        RuntimeError: If ``op`` is neither of the two remainder variants.
    """
    # The edge variant is checked first, then the aten variant.
    if op == exir_ops.edge.aten.remainder.Tensor:
        namespace = exir_ops.edge.aten
    elif op == torch.ops.aten.remainder.Tensor:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get remainder decomposition ops for op {op}")

    return (
        namespace.div.Tensor_mode,
        namespace.mul.Tensor,
        namespace.sub.Tensor,
    )


class DecomposeRemainderPass(ArmPass):
    """
    Rewrite ``remainder.Tensor`` as floor-division, multiply and subtract:
        remainder(x, y) -> x - div(x, y, rounding_mode="floor") * y

    The edge-dialect and the aten variant of the operator are both handled;
    every other operator is forwarded unchanged to the parent pass.
    """

    # NOTE(review): going by the attribute name, this declares that
    # DecomposeDivTensorModePass has to run after this pass. The decomposition
    # below emits a div.Tensor_mode node with rounding_mode="floor".
    _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass}

    def call_operator(self, op, args, kwargs, meta, updated=False):
        is_remainder = op in (
            exir_ops.edge.aten.remainder.Tensor,
            torch.ops.aten.remainder.Tensor,
        )
        if is_remainder:
            floor_div_op, multiply_op, subtract_op = _get_remainder_decomposition_ops(
                op
            )
            dividend = args[0]
            divisor = args[1]

            # quotient = floor(dividend / divisor)
            quotient = super().call_operator(
                floor_div_op,
                (dividend, divisor),
                {"rounding_mode": "floor"},
                meta,
                updated=True,
            )
            # whole_part = quotient * divisor
            whole_part = super().call_operator(
                multiply_op, (quotient, divisor), {}, meta, updated=True
            )
            # remainder = dividend - whole_part
            return super().call_operator(
                subtract_op, (dividend, whole_part), {}, meta, updated=True
            )

        # Not a remainder op: leave the node to the parent implementation.
        return super().call_operator(op, args, kwargs, meta, updated)
2 changes: 2 additions & 0 deletions backends/arm/_passes/replace_scalar_with_tensor_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
exir_ops.edge.aten.bitwise_and.Scalar: exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Scalar: exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.remainder.Scalar: exir_ops.edge.aten.remainder.Tensor,
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
Expand All @@ -55,6 +56,7 @@
torch.ops.aten.bitwise_and.Scalar: torch.ops.aten.bitwise_and.Tensor,
torch.ops.aten.bitwise_or.Scalar: torch.ops.aten.bitwise_or.Tensor,
torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor,
torch.ops.aten.remainder.Scalar: torch.ops.aten.remainder.Tensor,
}

_fp_profile_ops: Dict[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.remainder.Tensor,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.sub.Tensor,
Expand Down Expand Up @@ -185,6 +186,8 @@
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.remainder.Scalar,
exir_ops.edge.aten.remainder.Tensor,
exir_ops.edge.aten.leaky_relu.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.rsqrt.default,
Expand Down
131 changes: 131 additions & 0 deletions backends/arm/test/ops/test_remainder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)


def _nonzero_float_tensor(*shape: int) -> torch.Tensor:
    """Return a random float32 tensor of ``shape`` whose values are >= 0.1.

    Samples from ``torch.rand`` (uniform on [0, 1)) are scaled by 5 and
    shifted by 0.1, so no element is zero.
    """
    uniform = torch.rand(*shape, dtype=torch.float32)
    scaled = uniform * 5
    return scaled + 0.1


class Remainder(torch.nn.Module):
input_t = Tuple[torch.Tensor | float, torch.Tensor | float]

test_cases = {
"rank2_tensors": lambda: (
torch.randn(2, 3) * 7,
_nonzero_float_tensor(2, 3),
),
"rank4_tensors": lambda: (
torch.randn(1, 4, 2, 3) * 7,
_nonzero_float_tensor(1, 4, 2, 3),
),
"broadcast": lambda: (
torch.randn(4, 5, 1),
_nonzero_float_tensor(1, 5, 6),
),
"scalar_rhs": lambda: (
torch.randn(1, 2, 3, 4),
0.25,
),
}

def forward(self, x: torch.Tensor | float, y: torch.Tensor | float) -> torch.Tensor:
return torch.remainder(x, y)


def _get_aten_op(test_data: Remainder.input_t):
    """Return the name of the aten remainder overload expected for ``test_data``.

    The overload is selected by which operand is a Python float:

    * float divisor  -> ``remainder.Scalar``  (Tensor, Scalar)
    * float dividend -> ``remainder.Scalar_Tensor``  (Scalar, Tensor)
    * otherwise      -> ``remainder.Tensor``  (Tensor, Tensor)

    Args:
        test_data: The ``(dividend, divisor)`` pair fed to the module.

    Returns:
        The fully qualified overload name as a string.
    """
    dividend, divisor = test_data[0], test_data[1]
    # The divisor is checked first so that every input with a float divisor
    # keeps reporting ``remainder.Scalar``.
    if isinstance(divisor, float):
        return "torch.ops.aten.remainder.Scalar"
    # ``remainder.Scalar`` only covers a scalar divisor; a scalar dividend
    # dispatches to the dedicated ``Scalar_Tensor`` overload instead.
    if isinstance(dividend, float):
        return "torch.ops.aten.remainder.Scalar_Tensor"
    return "torch.ops.aten.remainder.Tensor"


def _get_exir_op(test_data: Remainder.input_t):
    """Return the edge-dialect remainder op name expected for ``test_data``.

    Only the second operand is inspected: a Python float selects the
    ``Scalar`` variant, anything else selects the ``Tensor`` variant.
    """
    divisor_is_float = isinstance(test_data[1], float)
    return (
        "executorch_exir_dialects_edge__ops_aten_remainder_Scalar"
        if divisor_is_float
        else "executorch_exir_dialects_edge__ops_aten_remainder_Tensor"
    )


@common.parametrize("test_data", Remainder.test_cases)
def test_remainder_tosa_FP(test_data):
    """Run one remainder case through ``TosaPipelineFP``."""
    inputs = test_data()
    module = Remainder()
    # The op names handed to the pipeline depend on whether the case uses a
    # float operand.
    aten_op = _get_aten_op(inputs)
    exir_op = _get_exir_op(inputs)
    TosaPipelineFP[Remainder.input_t](module, inputs, aten_op, exir_op).run()


@common.parametrize("test_data", Remainder.test_cases)
def test_remainder_tosa_INT(test_data):
    """Run one remainder case through ``TosaPipelineINT``."""
    module = Remainder()
    inputs = test_data()
    # NOTE(review): the third positional argument is an empty list; its
    # meaning is defined by TosaPipelineINT - confirm there.
    TosaPipelineINT[Remainder.input_t](module, inputs, []).run()


@common.parametrize("test_data", Remainder.test_cases)
@common.XfailIfNoCorstone300
def test_remainder_u55_INT(test_data):
    """Run one remainder case through ``EthosU55PipelineINT``."""
    module = Remainder()
    inputs = test_data()
    # NOTE(review): the third positional argument is an empty list; its
    # meaning is defined by EthosU55PipelineINT - confirm there.
    EthosU55PipelineINT[Remainder.input_t](module, inputs, []).run()


@common.parametrize("test_data", Remainder.test_cases)
@common.XfailIfNoCorstone320
def test_remainder_u85_INT(test_data):
    """Run one remainder case through ``EthosU85PipelineINT``."""
    module = Remainder()
    inputs = test_data()
    # NOTE(review): the third positional argument is an empty list; its
    # meaning is defined by EthosU85PipelineINT - confirm there.
    EthosU85PipelineINT[Remainder.input_t](module, inputs, []).run()


@common.parametrize("test_data", Remainder.test_cases)
@common.SkipIfNoModelConverter
def test_remainder_vgf_FP(test_data):
    """Run one remainder case through ``VgfPipeline`` with the FP TOSA version."""
    inputs = test_data()
    module = Remainder()
    # The op names handed to the pipeline depend on whether the case uses a
    # float operand.
    aten_op = _get_aten_op(inputs)
    exir_op = _get_exir_op(inputs)
    VgfPipeline[Remainder.input_t](
        module,
        inputs,
        aten_op,
        exir_op,
        tosa_version="TOSA-1.0+FP",
    ).run()


@common.parametrize("test_data", Remainder.test_cases)
@common.SkipIfNoModelConverter
def test_remainder_vgf_INT(test_data):
    """Run one remainder case through ``VgfPipeline`` with the INT TOSA version."""
    module = Remainder()
    inputs = test_data()
    # NOTE(review): the third positional argument is an empty list; its
    # meaning is defined by VgfPipeline - confirm there.
    VgfPipeline[Remainder.input_t](
        module,
        inputs,
        [],
        tosa_version="TOSA-1.0+INT",
    ).run()
Loading