Skip to content

Commit a726363

Browse files
committed
Arm backend: Add support for floor_divide.default
Signed-off-by: Agrima Khare <[email protected]> Change-Id: I85153b1b245862a107e6469600e69c059477f417
1 parent 418c584 commit a726363

File tree

5 files changed

+221
-0
lines changed

5 files changed

+221
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .decompose_elu_pass import DecomposeEluPass # noqa
4343
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
4444
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa
45+
from .decompose_floor_divide_pass import DecomposeFloorDividePass # noqa
4546
from .decompose_gelu_pass import DecomposeGeluPass # noqa
4647
from .decompose_glu_pass import DecomposeGluPass # noqa
4748
from .decompose_grouped_conv import DecomposeGroupedConv # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
DecomposeEluPass,
5252
DecomposeEmbeddingPass,
5353
DecomposeExpm1Pass,
54+
DecomposeFloorDividePass,
5455
DecomposeGeluPass,
5556
DecomposeGluPass,
5657
DecomposeGroupedConv,
@@ -242,6 +243,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
242243
self.add_pass(CastBoolToInt8Pass())
243244
self.add_pass(DecomposeSinhPass())
244245
self.add_pass(DecomposeSignPass())
246+
self.add_pass(DecomposeFloorDividePass())
245247
self.add_pass(DecomposeDivTensorModePass())
246248
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
247249
self.add_pass(DecomposeEmbeddingPass())
@@ -333,6 +335,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
333335
self.add_pass(CastBoolToInt8Pass())
334336
self.add_pass(DecomposeSignPass())
335337
self.add_pass(DecomposeAddmmPass())
338+
self.add_pass(DecomposeFloorDividePass())
336339
self.add_pass(DecomposeDivTensorModePass())
337340
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
338341
self.add_pass(ScalarsToAttributePass())
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm._passes.decompose_div_tensor_mode import (
11+
DecomposeDivTensorModePass,
12+
)
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import ExportPass
15+
16+
# floor_divide overloads recognised by this module, grouped by dialect.
# Edge-dialect overload:
edge_floor_divide_ops = (exir_ops.edge.aten.floor_divide.default,)
# Core ATen overload:
aten_floor_divide_ops = (torch.ops.aten.floor_divide.default,)
18+
19+
20+
def get_floor_divide_decomposition(op) -> tuple:
    """
    Look up the operator that replaces a given ``floor_divide`` overload.

    ``floor_divide(x, y)`` is expressed as
    ``div.Tensor_mode(x, y, rounding_mode="floor")``. This helper only picks
    the ``div.Tensor_mode`` overload from the same dialect (edge or core ATen)
    as ``op``; supplying the ``rounding_mode`` argument is left to the caller.

    Args:
        op: The operator overload to decompose. Must be a member of
            ``edge_floor_divide_ops`` or ``aten_floor_divide_ops``.

    Returns:
        A 1-tuple ``(div_op,)`` holding the matching ``div.Tensor_mode``
        overload.

    Raises:
        RuntimeError: If ``op`` is not one of the supported floor_divide
            overloads.
    """
    if op in edge_floor_divide_ops:
        div_op = exir_ops.edge.aten.div.Tensor_mode
    elif op in aten_floor_divide_ops:
        div_op = torch.ops.aten.div.Tensor_mode
    else:
        raise RuntimeError(f"Can't get floor_div decomposition for op {op}")

    return (div_op,)
42+
43+
44+
class DecomposeFloorDividePass(ArmPass):
    """
    Rewrite floor_divide nodes as div.Tensor_mode with rounding_mode="floor".

    Both the edge-dialect and the core ATen overloads of floor_divide are
    handled; every other operator is forwarded to the parent class unchanged.
    """

    # DecomposeDivTensorModePass is declared as required after this pass, since
    # this pass emits div.Tensor_mode nodes.
    _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivTensorModePass}

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace a floor_divide call with the equivalent div.Tensor_mode call.

        Only the first two positional arguments (dividend and divisor) are
        carried over to the replacement node; the original ``kwargs`` are not
        forwarded for floor_divide ops.
        """
        is_floor_divide = op in edge_floor_divide_ops or op in aten_floor_divide_ops
        if not is_floor_divide:
            return super().call_operator(op, args, kwargs, meta, updated=False)

        (div_op,) = get_floor_divide_decomposition(op)

        dividend = args[0]
        divisor = args[1]
        floor_kwargs = {"rounding_mode": "floor"}

        return super().call_operator(
            div_op, (dividend, divisor), floor_kwargs, meta, updated=True
        )

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@
227227
exir_ops.edge.aten.logit.default,
228228
exir_ops.edge.aten.acos.default,
229229
exir_ops.edge.aten.elu.default,
230+
exir_ops.edge.aten.floor_divide.default,
230231
}
231232

232233

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from typing import Tuple, Union
9+
10+
import torch
11+
from executorch.backends.arm.test import common
12+
13+
from executorch.backends.arm.test.tester.test_pipeline import (
14+
EthosU55PipelineINT,
15+
EthosU85PipelineINT,
16+
TosaPipelineFP,
17+
TosaPipelineINT,
18+
VgfPipeline,
19+
)
20+
21+
# Shape shared by the large rank-4 cases below.
_RANK4_SHAPE = (5, 10, 25, 20)

# Maps test name -> zero-argument factory returning the (input, other) pair.
# Factories are used so that tensors are only created when a case is run.
test_data_suite = {
    "op_floor_div_rank1_ones": lambda: (torch.ones(5), torch.ones(5)),
    "op_floor_div_rank1_rand": lambda: (torch.rand(5) * 5, torch.rand(5) * 5),
    "op_floor_div_rank4_negative_ones": lambda: (
        -torch.ones(_RANK4_SHAPE),
        torch.ones(_RANK4_SHAPE),
    ),
    "op_floor_div_rank4_ones_div_negative": lambda: (
        torch.ones(_RANK4_SHAPE),
        -torch.ones(_RANK4_SHAPE),
    ),
    "op_floor_div_rank4_large_rand": lambda: (
        200 * torch.rand(_RANK4_SHAPE),
        torch.rand(_RANK4_SHAPE),
    ),
    # The two operands have different size-1 dimensions.
    "op_floor_div_rank4_randn_mutltiple_broadcasts": lambda: (
        torch.randn(1, 4, 4, 1),
        torch.randn(1, 1, 4, 4),
    ),
    # `other` is a plain Python scalar rather than a tensor.
    "op_floor_div_rank4_randn_scalar": lambda: (torch.randn(1, 4, 4, 1), 2),
}
52+
53+
54+
class FloorDivide(torch.nn.Module):
    """Minimal module whose forward pass is a single ``torch.floor_divide``.

    The class attributes hold the operator names that the tests in this file
    hand to the pipelines: ``aten_op`` / ``exir_op`` for the FP pipelines and
    ``aten_ops_int`` / ``exir_ops_int`` for the INT pipelines.
    """

    aten_op = "torch.ops.aten.floor_divide.default"
    aten_ops_int = [
        "aten.mul.Tensor",
        "aten.reciprocal.default",
        "aten.floor.default",
    ]
    exir_op = "executorch_exir_dialects_edge__ops_aten_div_Tensor_mode"
    exir_ops_int = [
        "executorch_exir_dialects_edge__ops_aten_reciprocal_default",
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
        "executorch_exir_dialects_edge__ops_aten_floor_default",
    ]

    def forward(
        self,
        input_: Union[torch.Tensor, torch.types.Number],
        other_: Union[torch.Tensor, torch.types.Number],
    ):
        """Return ``input_`` floor-divided by ``other_``."""
        quotient = torch.floor_divide(input_, other_)
        return quotient
70+
71+
72+
# Type of one (input, other) test-data pair. `input` is always a tensor in the
# suite above; `other` may be a tensor or a plain Python number (the
# "op_floor_div_rank4_randn_scalar" case passes the integer 2), matching the
# annotation on FloorDivide.forward.
input_t1 = Tuple[torch.Tensor, Union[torch.Tensor, torch.types.Number]]
73+
74+
75+
@common.parametrize("test_data", test_data_suite)
def test_floor_divide_tosa_FP(test_data: input_t1):
    """Run FloorDivide through TosaPipelineFP for each case in test_data_suite."""
    module = FloorDivide()
    inputs = test_data()  # test_data is a factory; call it to build the operands
    TosaPipelineFP[input_t1](
        module,
        inputs,
        FloorDivide.aten_op,
        FloorDivide.exir_op,
        use_to_edge_transform_and_lower=False,
    ).run()
85+
86+
87+
@common.parametrize("test_data", test_data_suite)
def test_floor_divide_tosa_INT(test_data: input_t1):
    """Run FloorDivide through TosaPipelineINT for each case in test_data_suite."""
    module = FloorDivide()
    inputs = test_data()  # test_data is a factory; call it to build the operands
    TosaPipelineINT[input_t1](
        module,
        inputs,
        aten_op=FloorDivide.aten_ops_int,
        exir_op=FloorDivide.exir_ops_int,
        use_to_edge_transform_and_lower=False,
    ).run()
97+
98+
99+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
def test_floor_divide_u55_INT(test_data: input_t1):
    """Run FloorDivide through EthosU55PipelineINT with run_on_fvp enabled."""
    module = FloorDivide()
    inputs = test_data()  # test_data is a factory; call it to build the operands
    EthosU55PipelineINT[input_t1](
        module,
        inputs,
        aten_ops=FloorDivide.aten_ops_int,
        exir_ops=FloorDivide.exir_ops_int,
        run_on_fvp=True,
        use_to_edge_transform_and_lower=False,
    ).run()
111+
112+
113+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
def test_floor_divide_u85_INT(test_data: input_t1):
    """Run FloorDivide through EthosU85PipelineINT with run_on_fvp enabled."""
    module = FloorDivide()
    inputs = test_data()  # test_data is a factory; call it to build the operands
    EthosU85PipelineINT[input_t1](
        module,
        inputs,
        aten_ops=FloorDivide.aten_ops_int,
        exir_ops=FloorDivide.exir_ops_int,
        run_on_fvp=True,
        use_to_edge_transform_and_lower=False,
    ).run()
125+
126+
127+
@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_floor_divide_vgf_FP(test_data: input_t1):
    """Run FloorDivide through VgfPipeline with the "TOSA-1.0+FP" version."""
    module = FloorDivide()
    inputs = test_data()  # test_data is a factory; call it to build the operands
    VgfPipeline[input_t1](
        module,
        inputs,
        FloorDivide.aten_op,
        FloorDivide.exir_op,
        tosa_version="TOSA-1.0+FP",
        use_to_edge_transform_and_lower=False,
    ).run()
139+
140+
141+
@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_floor_divide_vgf_INT(test_data: input_t1):
    """Run FloorDivide through VgfPipeline with the "TOSA-1.0+INT" version."""
    module = FloorDivide()
    inputs = test_data()  # test_data is a factory; call it to build the operands
    VgfPipeline[input_t1](
        module,
        inputs,
        aten_op=FloorDivide.aten_ops_int,
        exir_op=FloorDivide.exir_ops_int,
        tosa_version="TOSA-1.0+INT",
        use_to_edge_transform_and_lower=False,
    ).run()

0 commit comments

Comments
 (0)