Skip to content

Commit 3dbc104

Browse files
martinlsmMartin Lindström
andauthored
Arm backend: Split tensor and scalar test cases for remainder (#15809)
There's a script backends/arm/scripts/parse_test_names.py that parses test output from pytest and displays operator test coverage based on that. The remainder operator tests were problematic because the test case names did not reflect the operator name correctly. The operator has two versions, namely remainder.Tensor and remainder.Scalar. Because of this, test case names should respectively be split into separate names. This patch updates the test case names such that the script can find them. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Martin Lindström <[email protected]> Co-authored-by: Martin Lindström <[email protected]>
1 parent 004bbd2 commit 3dbc104

File tree

1 file changed

+98
-30
lines changed

1 file changed

+98
-30
lines changed

backends/arm/test/ops/test_remainder.py

Lines changed: 98 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ def _nonzero_float_tensor(*shape: int) -> torch.Tensor:
2424
class Remainder(torch.nn.Module):
2525
input_t = Tuple[torch.Tensor | float, torch.Tensor | float]
2626

27-
test_cases = {
27+
aten_op_tensor = "torch.ops.aten.remainder.Tensor"
28+
exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_remainder_Tensor"
29+
aten_op_scalar = "torch.ops.aten.remainder.Scalar"
30+
exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_remainder_Scalar"
31+
32+
test_cases_tensor = {
2833
"rank2_tensors": lambda: (
2934
torch.randn(2, 3) * 7,
3035
_nonzero_float_tensor(2, 3),
@@ -37,44 +42,49 @@ class Remainder(torch.nn.Module):
3742
torch.randn(4, 5, 1),
3843
_nonzero_float_tensor(1, 5, 6),
3944
),
40-
"scalar_rhs": lambda: (
45+
}
46+
47+
test_cases_scalar = {
48+
"scalar_pos": lambda: (
4149
torch.randn(1, 2, 3, 4),
4250
0.25,
4351
),
52+
"scalar_neg": lambda: (
53+
torch.randn(3, 4),
54+
-0.25,
55+
),
4456
}
4557

4658
def forward(self, x: torch.Tensor | float, y: torch.Tensor | float) -> torch.Tensor:
4759
return torch.remainder(x, y)
4860

4961

50-
def _get_aten_op(test_data: Remainder.input_t):
51-
if any(isinstance(x, float) for x in test_data):
52-
return "torch.ops.aten.remainder.Scalar"
53-
else:
54-
return "torch.ops.aten.remainder.Tensor"
55-
56-
57-
def _get_exir_op(test_data: Remainder.input_t):
58-
if isinstance(test_data[1], float):
59-
return "executorch_exir_dialects_edge__ops_aten_remainder_Scalar"
60-
else:
61-
return "executorch_exir_dialects_edge__ops_aten_remainder_Tensor"
62+
@common.parametrize("test_data", Remainder.test_cases_tensor)
63+
def test_remainder_tensor_tosa_FP(test_data):
64+
data = test_data()
65+
pipeline = TosaPipelineFP[Remainder.input_t](
66+
Remainder(),
67+
data,
68+
Remainder.aten_op_tensor,
69+
Remainder.exir_op_tensor,
70+
)
71+
pipeline.run()
6272

6373

64-
@common.parametrize("test_data", Remainder.test_cases)
65-
def test_remainder_tosa_FP(test_data):
74+
@common.parametrize("test_data", Remainder.test_cases_scalar)
75+
def test_remainder_scalar_tosa_FP(test_data):
6676
data = test_data()
6777
pipeline = TosaPipelineFP[Remainder.input_t](
6878
Remainder(),
6979
data,
70-
_get_aten_op(data),
71-
_get_exir_op(data),
80+
Remainder.aten_op_scalar,
81+
Remainder.exir_op_scalar,
7282
)
7383
pipeline.run()
7484

7585

76-
@common.parametrize("test_data", Remainder.test_cases)
77-
def test_remainder_tosa_INT(test_data):
86+
@common.parametrize("test_data", Remainder.test_cases_tensor)
87+
def test_remainder_tensor_tosa_INT(test_data):
7888
pipeline = TosaPipelineINT[Remainder.input_t](
7989
Remainder(),
8090
test_data(),
@@ -83,9 +93,30 @@ def test_remainder_tosa_INT(test_data):
8393
pipeline.run()
8494

8595

86-
@common.parametrize("test_data", Remainder.test_cases)
96+
@common.parametrize("test_data", Remainder.test_cases_scalar)
97+
def test_remainder_scalar_tosa_INT(test_data):
98+
pipeline = TosaPipelineINT[Remainder.input_t](
99+
Remainder(),
100+
test_data(),
101+
[],
102+
)
103+
pipeline.run()
104+
105+
106+
@common.parametrize("test_data", Remainder.test_cases_tensor)
107+
@common.XfailIfNoCorstone300
108+
def test_remainder_tensor_u55_INT(test_data):
109+
pipeline = EthosU55PipelineINT[Remainder.input_t](
110+
Remainder(),
111+
test_data(),
112+
[],
113+
)
114+
pipeline.run()
115+
116+
117+
@common.parametrize("test_data", Remainder.test_cases_scalar)
87118
@common.XfailIfNoCorstone300
88-
def test_remainder_u55_INT(test_data):
119+
def test_remainder_scalar_u55_INT(test_data):
89120
pipeline = EthosU55PipelineINT[Remainder.input_t](
90121
Remainder(),
91122
test_data(),
@@ -94,9 +125,9 @@ def test_remainder_u55_INT(test_data):
94125
pipeline.run()
95126

96127

97-
@common.parametrize("test_data", Remainder.test_cases)
128+
@common.parametrize("test_data", Remainder.test_cases_tensor)
98129
@common.XfailIfNoCorstone320
99-
def test_remainder_u85_INT(test_data):
130+
def test_remainder_tensor_u85_INT(test_data):
100131
pipeline = EthosU85PipelineINT[Remainder.input_t](
101132
Remainder(),
102133
test_data(),
@@ -105,23 +136,60 @@ def test_remainder_u85_INT(test_data):
105136
pipeline.run()
106137

107138

108-
@common.parametrize("test_data", Remainder.test_cases)
139+
@common.parametrize("test_data", Remainder.test_cases_scalar)
140+
@common.XfailIfNoCorstone320
141+
def test_remainder_scalar_u85_INT(test_data):
142+
pipeline = EthosU85PipelineINT[Remainder.input_t](
143+
Remainder(),
144+
test_data(),
145+
[],
146+
)
147+
pipeline.run()
148+
149+
150+
@common.parametrize("test_data", Remainder.test_cases_tensor)
109151
@common.SkipIfNoModelConverter
110-
def test_remainder_vgf_FP(test_data):
152+
def test_remainder_tensor_vgf_FP(test_data):
111153
data = test_data()
112154
pipeline = VgfPipeline[Remainder.input_t](
113155
Remainder(),
114156
data,
115-
_get_aten_op(data),
116-
_get_exir_op(data),
157+
Remainder.aten_op_tensor,
158+
Remainder.exir_op_tensor,
117159
tosa_version="TOSA-1.0+FP",
118160
)
119161
pipeline.run()
120162

121163

122-
@common.parametrize("test_data", Remainder.test_cases)
164+
@common.parametrize("test_data", Remainder.test_cases_scalar)
165+
@common.SkipIfNoModelConverter
166+
def test_remainder_scalar_vgf_FP(test_data):
167+
data = test_data()
168+
pipeline = VgfPipeline[Remainder.input_t](
169+
Remainder(),
170+
data,
171+
Remainder.aten_op_scalar,
172+
Remainder.exir_op_scalar,
173+
tosa_version="TOSA-1.0+FP",
174+
)
175+
pipeline.run()
176+
177+
178+
@common.parametrize("test_data", Remainder.test_cases_tensor)
179+
@common.SkipIfNoModelConverter
180+
def test_remainder_tensor_vgf_INT(test_data):
181+
pipeline = VgfPipeline[Remainder.input_t](
182+
Remainder(),
183+
test_data(),
184+
[],
185+
tosa_version="TOSA-1.0+INT",
186+
)
187+
pipeline.run()
188+
189+
190+
@common.parametrize("test_data", Remainder.test_cases_scalar)
123191
@common.SkipIfNoModelConverter
124-
def test_remainder_vgf_INT(test_data):
192+
def test_remainder_scalar_vgf_INT(test_data):
125193
pipeline = VgfPipeline[Remainder.input_t](
126194
Remainder(),
127195
test_data(),

0 commit comments

Comments
 (0)