Skip to content

Commit 66ba13f

Browse files
martinlsm (Martin Lindström)
authored and committed
Arm backend: Add support for torch.remainder (pytorch#15409)
Add DecomposeRemainderPass, which expresses tensor remainder via floor division, multiplication, and subtraction. A test for the new operator has been added in backends/arm/test/ops/test_remainder.py. Signed-off-by: Martin Lindström <[email protected]> Co-authored-by: Martin Lindström <[email protected]>
1 parent da7cd68 commit 66ba13f

File tree

6 files changed

+208
-1
lines changed

6 files changed

+208
-1
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@
5959
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
6060
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
6161
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
62+
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
6263
from .decompose_round_pass import DecomposeRoundPass # noqa
6364
from .decompose_select import DecomposeSelectPass # noqa
6465
from .decompose_sign_pass import DecomposeSignPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@
6363
DecomposeMaxPool2DPass,
6464
DecomposeMeanDimPass,
6565
DecomposeNotEqualPass,
66+
DecomposeRemainderPass,
6667
DecomposeRoundPass,
6768
DecomposeSelectPass,
6869
DecomposeSignPass,
@@ -244,6 +245,8 @@ def _tosa_FP_pipeline(
244245
self.add_pass(DecomposeFloorDividePass())
245246
self.add_pass(DecomposeDivTensorModePass())
246247
self.add_pass(ReplaceScalarWithTensorByProfilePass())
248+
self.add_pass(DecomposeRemainderPass())
249+
self.add_pass(DecomposeDivTensorModePass())
247250
self.add_pass(DecomposeEmbeddingPass())
248251
self.add_pass(FuseQuantizedActivationPass())
249252
self.add_pass(RemoveGetItemPass())
@@ -334,9 +337,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
334337
self.add_pass(DecomposeSignPass())
335338
self.add_pass(DecomposeAddmmPass())
336339
self.add_pass(DecomposeFloorDividePass())
340+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
341+
self.add_pass(DecomposeRemainderPass())
337342
self.add_pass(DecomposeDivTensorModePass())
338343
self.add_pass(DecomposeAddSubAlphaPass())
339-
self.add_pass(ReplaceScalarWithTensorByProfilePass())
340344
self.add_pass(ScalarsToAttributePass())
341345
self.add_pass(DecomposeGroupNormPass())
342346
self.add_pass(DecomposeLayerNormPass())
backends/arm/_passes/decompose_remainder_pass.py

Lines changed: 66 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.decompose_div_tensor_mode import (
11+
DecomposeDivTensorModePass,
12+
)
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
15+
from executorch.exir.pass_base import ExportPass
16+
from torch._ops import OpOverload
17+
18+
Op = OpOverload | EdgeOpOverload
19+
20+
21+
def _get_remainder_decomposition_ops(op: Op) -> tuple[Op, Op, Op]:
    """
    Look up the operators used to lower a ``remainder.Tensor`` op.

    The returned triple is ``(div_mode_op, mul_op, sub_op)``. All three come
    from the same op namespace as ``op`` itself: the edge dialect when ``op``
    is the edge overload, plain aten when ``op`` is the aten overload.

    Raises:
        RuntimeError: if ``op`` is neither of the two remainder overloads.
    """
    # Pick the namespace first; the three overload names are the same in both.
    if op == exir_ops.edge.aten.remainder.Tensor:
        namespace = exir_ops.edge.aten
    elif op == torch.ops.aten.remainder.Tensor:
        namespace = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get remainder decomposition ops for op {op}")

    return (
        namespace.div.Tensor_mode,
        namespace.mul.Tensor,
        namespace.sub.Tensor,
    )
40+
41+
42+
class DecomposeRemainderPass(ArmPass):
    """
    Rewrite ``remainder.Tensor`` in terms of simpler arithmetic operators.

    For ``remainder(x, y)`` the pass emits::

        x - div(x, y, rounding_mode="floor") * y

    Both the edge-dialect and the aten overload of ``remainder.Tensor`` are
    handled; any other operator is forwarded to the parent pass untouched.
    """

    # The rewrite emits a ``div.Tensor_mode`` node; presumably this set tells
    # the pass manager that DecomposeDivTensorModePass has to run afterwards
    # to lower it (NOTE(review): confirm against ArmPass).
    _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass}

    def call_operator(self, op, args, kwargs, meta, updated=False):
        is_remainder = op in (
            exir_ops.edge.aten.remainder.Tensor,
            torch.ops.aten.remainder.Tensor,
        )
        if not is_remainder:
            # Not ours to rewrite: defer to the parent implementation as-is.
            return super().call_operator(op, args, kwargs, meta, updated)

        div_op, mul_op, sub_op = _get_remainder_decomposition_ops(op)
        dividend, divisor = args[0], args[1]

        # quotient = floor(dividend / divisor)
        quotient = super().call_operator(
            div_op,
            (dividend, divisor),
            {"rounding_mode": "floor"},
            meta,
            updated=True,
        )
        # scaled = quotient * divisor
        scaled = super().call_operator(
            mul_op, (quotient, divisor), {}, meta, updated=True
        )
        # remainder = dividend - scaled
        return super().call_operator(
            sub_op, (dividend, scaled), {}, meta, updated=True
        )

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
exir_ops.edge.aten.bitwise_and.Scalar: exir_ops.edge.aten.bitwise_and.Tensor,
4141
exir_ops.edge.aten.bitwise_or.Scalar: exir_ops.edge.aten.bitwise_or.Tensor,
4242
exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.bitwise_xor.Tensor,
43+
exir_ops.edge.aten.remainder.Scalar: exir_ops.edge.aten.remainder.Tensor,
4344
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
4445
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
4546
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
@@ -55,6 +56,7 @@
5556
torch.ops.aten.bitwise_and.Scalar: torch.ops.aten.bitwise_and.Tensor,
5657
torch.ops.aten.bitwise_or.Scalar: torch.ops.aten.bitwise_or.Tensor,
5758
torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor,
59+
torch.ops.aten.remainder.Scalar: torch.ops.aten.remainder.Tensor,
5860
}
5961

6062
_fp_profile_ops: Dict[

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@
7777
exir_ops.edge.aten.repeat.default,
7878
exir_ops.edge.aten.reciprocal.default,
7979
exir_ops.edge.aten.relu.default,
80+
exir_ops.edge.aten.remainder.Tensor,
8081
exir_ops.edge.aten.rsqrt.default,
8182
exir_ops.edge.aten.select_copy.int,
8283
exir_ops.edge.aten.sub.Tensor,
@@ -185,6 +186,8 @@
185186
exir_ops.edge.aten.repeat.default,
186187
exir_ops.edge.aten.reciprocal.default,
187188
exir_ops.edge.aten.relu.default,
189+
exir_ops.edge.aten.remainder.Scalar,
190+
exir_ops.edge.aten.remainder.Tensor,
188191
exir_ops.edge.aten.leaky_relu.default,
189192
exir_ops.edge.aten.sqrt.default,
190193
exir_ops.edge.aten.rsqrt.default,
backends/arm/test/ops/test_remainder.py

Lines changed: 131 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineINT,
13+
EthosU85PipelineINT,
14+
TosaPipelineFP,
15+
TosaPipelineINT,
16+
VgfPipeline,
17+
)
18+
19+
20+
def _nonzero_float_tensor(*shape: int) -> torch.Tensor:
    """Return a random float32 tensor of ``shape`` whose values are all positive.

    Samples from ``torch.rand`` are scaled by 5 and shifted by 0.1, so every
    element is at least 0.1 and the tensor is safe to use as a divisor.
    """
    unit_samples = torch.rand(*shape, dtype=torch.float32)
    return unit_samples * 5 + 0.1
22+
23+
24+
class Remainder(torch.nn.Module):
25+
input_t = Tuple[torch.Tensor | float, torch.Tensor | float]
26+
27+
test_cases = {
28+
"rank2_tensors": lambda: (
29+
torch.randn(2, 3) * 7,
30+
_nonzero_float_tensor(2, 3),
31+
),
32+
"rank4_tensors": lambda: (
33+
torch.randn(1, 4, 2, 3) * 7,
34+
_nonzero_float_tensor(1, 4, 2, 3),
35+
),
36+
"broadcast": lambda: (
37+
torch.randn(4, 5, 1),
38+
_nonzero_float_tensor(1, 5, 6),
39+
),
40+
"scalar_rhs": lambda: (
41+
torch.randn(1, 2, 3, 4),
42+
0.25,
43+
),
44+
}
45+
46+
def forward(self, x: torch.Tensor | float, y: torch.Tensor | float) -> torch.Tensor:
47+
return torch.remainder(x, y)
48+
49+
50+
def _get_aten_op(test_data: Remainder.input_t):
    """Return the aten remainder op name string matching ``test_data``.

    The ``Scalar`` overload name is returned when any element of
    ``test_data`` is a Python float; otherwise the ``Tensor`` overload name.
    """
    has_float_operand = any(isinstance(operand, float) for operand in test_data)
    return (
        "torch.ops.aten.remainder.Scalar"
        if has_float_operand
        else "torch.ops.aten.remainder.Tensor"
    )
55+
56+
57+
def _get_exir_op(test_data: Remainder.input_t):
    """Return the edge-dialect remainder op name string matching ``test_data``.

    Only the second element (the divisor) is inspected: a Python float
    selects the ``Scalar`` overload name, anything else the ``Tensor`` one.
    """
    divisor = test_data[1]
    return (
        "executorch_exir_dialects_edge__ops_aten_remainder_Scalar"
        if isinstance(divisor, float)
        else "executorch_exir_dialects_edge__ops_aten_remainder_Tensor"
    )
62+
63+
64+
@common.parametrize("test_data", Remainder.test_cases)
def test_remainder_tosa_FP(test_data):
    """Build and run a ``TosaPipelineFP`` for one ``Remainder`` test case."""
    # ``test_data`` is a factory; call it once so the pipeline and the
    # op-name helpers all see the same inputs.
    inputs = test_data()
    aten_op = _get_aten_op(inputs)
    exir_op = _get_exir_op(inputs)
    TosaPipelineFP[Remainder.input_t](
        Remainder(),
        inputs,
        aten_op,
        exir_op,
    ).run()
74+
75+
76+
@common.parametrize("test_data", Remainder.test_cases)
def test_remainder_tosa_INT(test_data):
    """Build and run a ``TosaPipelineINT`` for one ``Remainder`` test case."""
    inputs = test_data()
    # The third positional argument is an empty list here; the FP variant of
    # this test passes the aten op name in the same position.
    TosaPipelineINT[Remainder.input_t](
        Remainder(),
        inputs,
        [],
    ).run()
84+
85+
86+
@common.parametrize("test_data", Remainder.test_cases)
@common.XfailIfNoCorstone300
def test_remainder_u55_INT(test_data):
    """Build and run an ``EthosU55PipelineINT`` for one ``Remainder`` test case."""
    inputs = test_data()
    # The third positional argument is an empty list, as in the other INT tests.
    EthosU55PipelineINT[Remainder.input_t](
        Remainder(),
        inputs,
        [],
    ).run()
95+
96+
97+
@common.parametrize("test_data", Remainder.test_cases)
@common.XfailIfNoCorstone320
def test_remainder_u85_INT(test_data):
    """Build and run an ``EthosU85PipelineINT`` for one ``Remainder`` test case."""
    inputs = test_data()
    # The third positional argument is an empty list, as in the other INT tests.
    EthosU85PipelineINT[Remainder.input_t](
        Remainder(),
        inputs,
        [],
    ).run()
106+
107+
108+
@common.parametrize("test_data", Remainder.test_cases)
@common.SkipIfNoModelConverter
def test_remainder_vgf_FP(test_data):
    """Build and run a ``VgfPipeline`` (TOSA-1.0+FP) for one ``Remainder`` test case."""
    # ``test_data`` is a factory; call it once so the pipeline and the
    # op-name helpers all see the same inputs.
    inputs = test_data()
    aten_op = _get_aten_op(inputs)
    exir_op = _get_exir_op(inputs)
    VgfPipeline[Remainder.input_t](
        Remainder(),
        inputs,
        aten_op,
        exir_op,
        tosa_version="TOSA-1.0+FP",
    ).run()
120+
121+
122+
@common.parametrize("test_data", Remainder.test_cases)
@common.SkipIfNoModelConverter
def test_remainder_vgf_INT(test_data):
    """Build and run a ``VgfPipeline`` (TOSA-1.0+INT) for one ``Remainder`` test case."""
    inputs = test_data()
    # The third positional argument is an empty list, as in the other INT tests.
    VgfPipeline[Remainder.input_t](
        Remainder(),
        inputs,
        [],
        tosa_version="TOSA-1.0+INT",
    ).run()

0 commit comments

Comments
 (0)