Skip to content

Commit c52d0bf

Browse files
Arm backend: Divide test_unary into test_floor and test_ceil (#12673)
This makes it easier to parse which operators have had their lowering implemented. Additionally, it becomes easier to add test cases that are specific to the respective operator. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent da97a0e commit c52d0bf

File tree

2 files changed

+115
-73
lines changed

2 files changed

+115
-73
lines changed

backends/arm/test/ops/test_ceil.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
EthosU55PipelineBI,
12+
EthosU85PipelineBI,
13+
TosaPipelineBI,
14+
TosaPipelineMI,
15+
)
16+
17+
input_t1 = Tuple[torch.Tensor]
18+
19+
20+
class Ceil(torch.nn.Module):
21+
def forward(self, x: torch.Tensor):
22+
return torch.ceil(x)
23+
24+
aten_op = "torch.ops.aten.ceil.default"
25+
exir_op = "executorch_exir_dialects_edge__ops_aten_ceil_default"
26+
27+
28+
zeros = torch.zeros(1, 10, 10, 10)
29+
ones = torch.ones(10, 10, 10)
30+
rand = torch.rand(10, 10) - 0.5
31+
randn_pos = torch.randn(1, 4, 4, 4) + 10
32+
randn_neg = torch.randn(1, 4, 4, 4) - 10
33+
ramp = torch.arange(-16, 16, 0.2)
34+
35+
test_data = {
36+
"ceil_zeros": lambda: (Ceil(), zeros),
37+
"ceil_ones": lambda: (Ceil(), ones),
38+
"ceil_rand": lambda: (Ceil(), rand),
39+
"ceil_randn_pos": lambda: (Ceil(), randn_pos),
40+
"ceil_randn_neg": lambda: (Ceil(), randn_neg),
41+
"ceil_ramp": lambda: (Ceil(), ramp),
42+
}
43+
44+
45+
@common.parametrize("test_data", test_data)
46+
def test_ceil_tosa_MI(test_data: input_t1):
47+
module, data = test_data()
48+
pipeline = TosaPipelineMI[input_t1](
49+
module,
50+
(data,),
51+
module.aten_op,
52+
module.exir_op,
53+
)
54+
pipeline.run()
55+
56+
57+
@common.parametrize("test_data", test_data)
58+
def test_ceil_tosa_BI(test_data: input_t1):
59+
module, data = test_data()
60+
pipeline = TosaPipelineBI[input_t1](
61+
module,
62+
(data,),
63+
module.aten_op,
64+
module.exir_op,
65+
atol=0.06,
66+
rtol=0.01,
67+
)
68+
pipeline.run()
69+
70+
71+
@common.parametrize("test_data", test_data)
72+
@common.XfailIfNoCorstone300
73+
def test_ceil_u55_BI(test_data: input_t1):
74+
module, data = test_data()
75+
pipeline = EthosU55PipelineBI[input_t1](
76+
module,
77+
(data,),
78+
module.aten_op,
79+
module.exir_op,
80+
run_on_fvp=True,
81+
)
82+
pipeline.run()
83+
84+
85+
@common.parametrize("test_data", test_data)
86+
@common.XfailIfNoCorstone320
87+
def test_ceil_u85_BI(test_data: input_t1):
88+
module, data = test_data()
89+
pipeline = EthosU85PipelineBI[input_t1](
90+
module,
91+
(data,),
92+
module.aten_op,
93+
module.exir_op,
94+
run_on_fvp=True,
95+
)
96+
pipeline.run()

backends/arm/test/ops/test_unary.py renamed to backends/arm/test/ops/test_floor.py

Lines changed: 19 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,13 @@
1414
TosaPipelineMI,
1515
)
1616

17-
18-
input_t1 = Tuple[torch.Tensor] # Input x
19-
20-
21-
class Ceil(torch.nn.Module):
22-
def forward(self, x: torch.Tensor):
23-
return torch.ceil(x)
24-
25-
op_name = "ceil"
26-
aten_op = "torch.ops.aten.ceil.default"
27-
exir_op = "executorch_exir_dialects_edge__ops_aten_ceil_default"
17+
input_t1 = Tuple[torch.Tensor]
2818

2919

3020
class Floor(torch.nn.Module):
3121
def forward(self, x: torch.Tensor):
3222
return torch.floor(x)
3323

34-
op_name = "floor"
3524
aten_op = "torch.ops.aten.floor.default"
3625
exir_op = "executorch_exir_dialects_edge__ops_aten_floor_default"
3726

@@ -43,77 +32,34 @@ def forward(self, x: torch.Tensor):
4332
randn_neg = torch.randn(1, 4, 4, 4) - 10
4433
ramp = torch.arange(-16, 16, 0.2)
4534

46-
4735
test_data = {
48-
"ceil_zeros": lambda: (
49-
Ceil(),
50-
zeros,
51-
),
52-
"floor_zeros": lambda: (
53-
Floor(),
54-
zeros,
55-
),
56-
"ceil_ones": lambda: (
57-
Ceil(),
58-
ones,
59-
),
60-
"floor_ones": lambda: (
61-
Floor(),
62-
ones,
63-
),
64-
"ceil_rand": lambda: (
65-
Ceil(),
66-
rand,
67-
),
68-
"floor_rand": lambda: (
69-
Floor(),
70-
rand,
71-
),
72-
"ceil_randn_pos": lambda: (
73-
Ceil(),
74-
randn_pos,
75-
),
76-
"floor_randn_pos": lambda: (
77-
Floor(),
78-
randn_pos,
79-
),
80-
"ceil_randn_neg": lambda: (
81-
Ceil(),
82-
randn_neg,
83-
),
84-
"floor_randn_neg": lambda: (
85-
Floor(),
86-
randn_neg,
87-
),
88-
"ceil_ramp": lambda: (
89-
Ceil(),
90-
ramp,
91-
),
92-
"floor_ramp": lambda: (
93-
Floor(),
94-
ramp,
95-
),
36+
"floor_zeros": lambda: (Floor(), zeros),
37+
"floor_ones": lambda: (Floor(), ones),
38+
"floor_rand": lambda: (Floor(), rand),
39+
"floor_randn_pos": lambda: (Floor(), randn_pos),
40+
"floor_randn_neg": lambda: (Floor(), randn_neg),
41+
"floor_ramp": lambda: (Floor(), ramp),
9642
}
9743

9844

9945
@common.parametrize("test_data", test_data)
100-
def test_unary_tosa_MI(test_data: input_t1):
101-
module, test_data = test_data()
46+
def test_floor_tosa_MI(test_data: input_t1):
47+
module, data = test_data()
10248
pipeline = TosaPipelineMI[input_t1](
10349
module,
104-
(test_data,),
50+
(data,),
10551
module.aten_op,
10652
module.exir_op,
10753
)
10854
pipeline.run()
10955

11056

11157
@common.parametrize("test_data", test_data)
112-
def test_unary_tosa_BI(test_data: input_t1):
113-
module, test_data = test_data()
58+
def test_floor_tosa_BI(test_data: input_t1):
59+
module, data = test_data()
11460
pipeline = TosaPipelineBI[input_t1](
11561
module,
116-
(test_data,),
62+
(data,),
11763
module.aten_op,
11864
module.exir_op,
11965
atol=0.06,
@@ -124,11 +70,11 @@ def test_unary_tosa_BI(test_data: input_t1):
12470

12571
@common.parametrize("test_data", test_data)
12672
@common.XfailIfNoCorstone300
127-
def test_unary_u55_BI(test_data: input_t1):
128-
module, test_data = test_data()
73+
def test_floor_u55_BI(test_data: input_t1):
74+
module, data = test_data()
12975
pipeline = EthosU55PipelineBI[input_t1](
13076
module,
131-
(test_data,),
77+
(data,),
13278
module.aten_op,
13379
module.exir_op,
13480
run_on_fvp=True,
@@ -138,11 +84,11 @@ def test_unary_u55_BI(test_data: input_t1):
13884

13985
@common.parametrize("test_data", test_data)
14086
@common.XfailIfNoCorstone320
141-
def test_unary_u85_BI(test_data: input_t1):
142-
module, test_data = test_data()
87+
def test_floor_u85_BI(test_data: input_t1):
88+
module, data = test_data()
14389
pipeline = EthosU85PipelineBI[input_t1](
14490
module,
145-
(test_data,),
91+
(data,),
14692
module.aten_op,
14793
module.exir_op,
14894
run_on_fvp=True,

0 commit comments

Comments
 (0)