Commit a373925

Arm backend: Align handling of flaky arm unittests (#7669)

* Align handling of flaky arm unittests
  - Removes uses of torch.manual_seed, which previously fixed the random state.
  - Adds the pytest plugin pytest-rerunfailures to mark flaky tests.
  - Refactors flaky tests to use data generators instead of pregenerated data, which ensures that data is re-randomized between reruns.
  - Updates the layer_norm test case to use the same qtol value for TOSA and Ethos-U targets.

  Note that removing the fixed seed may surface more flakiness in CI; this will have to be addressed with the flaky mark on a case-by-case basis over time.
* Fix expectedFailure
1 parent 4a8fb17 commit a373925
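
The shape of the change, condensed into a standalone sketch: test data lists hold zero-argument callables instead of tensors, and flaky tests carry pytest.mark.flaky so that pytest-rerunfailures can retry them against freshly drawn data. The class and test below are illustrative, not actual ExecuTorch test code, and the rerun budget is assumed to come from the suite's pytest configuration rather than from the bare mark itself.

import unittest
from typing import Callable, Tuple

import pytest
import torch
from parameterized import parameterized


class TestBmmSketch(unittest.TestCase):
    # Zero-argument lambdas defer tensor creation: module-level torch.rand(...)
    # calls would run once at import, so every rerun would replay identical data.
    test_data_generators = [
        lambda: (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
        lambda: (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
    ]

    @parameterized.expand(test_data_generators)
    @pytest.mark.flaky  # pytest-rerunfailures retries on failure, re-invoking the lambda
    def test_bmm(self, test_data_generator: Callable[[], Tuple]):
        x, y = test_data_generator()  # fresh random tensors on every (re)run
        torch.testing.assert_close(torch.bmm(x, y), x @ y)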

File tree

10 files changed: +189 −202 lines changed

backends/arm/test/ops/test_bmm.py

Lines changed: 52 additions & 47 deletions

@@ -6,7 +6,7 @@
 
 import unittest
 
-from typing import Tuple
+from typing import Callable, Tuple
 
 import pytest
 
@@ -16,39 +16,37 @@
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from parameterized import parameterized
 
-torch.manual_seed(1)
-
 
 class TestBMM(unittest.TestCase):
     """Tests Batch MatMul"""
 
     class BMM(torch.nn.Module):
-        test_parameters = [
-            (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
-            (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
-            (torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
-            (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
-            (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
+        test_data_generators = [
+            lambda: (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
+            lambda: (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
+            lambda: (torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
+            lambda: (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
+            lambda: (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
         ]
 
         def forward(self, x, y):
             return torch.bmm(x, y)
 
     class MatMul(torch.nn.Module):
-        test_parameters = [
-            (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
-            (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
+        test_data_generators = [
+            lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
+            lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
         ]
 
         def forward(self, x, y):
             return torch.matmul(x, y)
 
     class BMMSingleInput(torch.nn.Module):
-        test_parameters = [
-            (torch.rand(20, 3, 3),),
-            (torch.rand(2, 128, 128),),
-            (10000 * torch.randn(4, 25, 25),),
-            (5 + 5 * torch.randn(3, 64, 64),),
+        test_data_generators = [
+            lambda: (torch.rand(20, 3, 3),),
+            lambda: (torch.rand(2, 128, 128),),
+            lambda: (10000 * torch.randn(4, 25, 25),),
+            lambda: (5 + 5 * torch.randn(3, 64, 64),),
         ]
 
         def forward(self, x):
@@ -120,67 +118,74 @@ def _test_bmm_ethosu_BI_pipeline(
         if conftest.is_option_enabled("corstone_fvp"):
             tester.run_method_and_compare_outputs(inputs=test_data, qtol=1)
 
-    @parameterized.expand(BMM.test_parameters)
-    def test_bmm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
-        test_data = (operand1, operand2)
+    @parameterized.expand(BMM.test_data_generators)
+    def test_bmm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+        test_data = test_data_generator()
         self._test_bmm_tosa_MI_pipeline(self.BMM(), test_data)
 
-    @parameterized.expand(BMMSingleInput.test_parameters)
-    def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor):
-        test_data = (operand1,)
+    @parameterized.expand(BMMSingleInput.test_data_generators)
+    def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+        test_data = test_data_generator()
        self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)
 
-    @parameterized.expand(MatMul.test_parameters)
-    def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
-        test_data = (operand1, operand2)
+    @parameterized.expand(MatMul.test_data_generators)
+    def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+        test_data = test_data_generator()
         self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data)
 
-    @parameterized.expand(MatMul.test_parameters)
-    def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
-        test_data = (operand1, operand2)
+    @parameterized.expand(MatMul.test_data_generators)
+    @pytest.mark.flaky  # TODO: Investigate flakyness (MLETORCH-534)
+    def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+        test_data = test_data_generator()
         self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data)
 
-    @parameterized.expand(BMM.test_parameters)
-    def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
-        test_data = (operand1, operand2)
+    @parameterized.expand(BMM.test_data_generators)
+    @pytest.mark.flaky  # TODO: Investigate flakyness (MLETORCH-534)
+    def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+        test_data = test_data_generator()
         self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)
 
-    @parameterized.expand(BMMSingleInput.test_parameters)
-    def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor):
-        test_data = (operand1,)
+    @parameterized.expand(BMMSingleInput.test_data_generators)
+    @pytest.mark.flaky  # TODO: Investigate flakyness (MLETORCH-534)
+    def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+        test_data = test_data_generator()
         self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)
 
-    @parameterized.expand(BMM.test_parameters)
+    @parameterized.expand(BMM.test_data_generators)
     @pytest.mark.corstone_fvp
     @unittest.expectedFailure
-    def test_bmm_u55_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor):
-        test_data = (operand1, operand2)
+    def test_bmm_u55_BI_xfails(self, test_data_generator: Callable[[], Tuple]):
+        test_data = test_data_generator()
         self._test_bmm_ethosu_BI_pipeline(
             self.BMM(), common.get_u55_compile_spec(), test_data
         )
 
-    @parameterized.expand(BMM.test_parameters)
+    @parameterized.expand(BMM.test_data_generators)
     @pytest.mark.corstone_fvp
-    def test_bmm_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
-        test_data = (operand1, operand2)
+    @pytest.mark.flaky  # TODO: Investigate flakyness (MLETORCH-534)
+    def test_bmm_u85_BI(self, test_data_generator: Callable[[], Tuple]):
+        test_data = test_data_generator()
         self._test_bmm_ethosu_BI_pipeline(
             self.BMM(), common.get_u85_compile_spec(), test_data
         )
 
     # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
-    @parameterized.expand(BMMSingleInput.test_parameters)
+    @parameterized.expand(BMMSingleInput.test_data_generators)
     @pytest.mark.corstone_fvp
     @unittest.expectedFailure
-    def test_bmm_single_input_u55_BI_xfails(self, operand1: torch.Tensor):
-        test_data = (operand1,)
+    def test_bmm_single_input_u55_BI_xfails(
+        self, test_data_generator: Callable[[], Tuple]
+    ):
+        test_data = test_data_generator()
         self._test_bmm_ethosu_BI_pipeline(
             self.BMMSingleInput(), common.get_u55_compile_spec(), test_data
         )
 
-    @parameterized.expand(BMMSingleInput.test_parameters)
+    @parameterized.expand(BMMSingleInput.test_data_generators)
     @pytest.mark.corstone_fvp
-    def test_bmm_single_input_u85_BI(self, operand1: torch.Tensor):
-        test_data = (operand1,)
+    @pytest.mark.flaky  # TODO: Investigate flakyness (MLETORCH-534)
+    def test_bmm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]):
+        test_data = test_data_generator()
         self._test_bmm_ethosu_BI_pipeline(
             self.BMMSingleInput(), common.get_u85_compile_spec(), test_data
         )
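
Why the lists hold lambdas rather than tensors: @parameterized.expand consumes its parameter list when the module is imported, so tensors placed directly in the list are sampled exactly once and every rerun replays the same values. A tiny standalone illustration, not from the commit:

import torch

eager = (torch.rand(2, 2),)         # sampled once, at import; reruns see the same values
lazy = lambda: (torch.rand(2, 2),)  # sampled on demand; each call draws new values

print(torch.equal(eager[0], eager[0]))    # True: one tensor, reused forever
print(torch.equal(lazy()[0], lazy()[0]))  # False (almost surely): two independent draws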

backends/arm/test/ops/test_conv1d.py

Lines changed: 8 additions & 8 deletions

@@ -6,7 +6,7 @@
 
 import unittest
 
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union
 
 import pytest
 
@@ -25,7 +25,6 @@ class Conv1d(torch.nn.Module):
 
     def __init__(
         self,
-        inputs: Optional[torch.Tensor] = None,
         length=8,
        nbr_conv=1,  # Number of chained convs
         in_channels: Union[List, int, None] = None,
@@ -75,11 +74,10 @@ def __init__(
         if not isinstance(padding_mode, List):
             padding_mode = [padding_mode]
 
-        # Generate test data if not provided
-        if inputs is None:
-            self.inputs = (torch.randn(batches, in_channels[0], length).to(dtype),)
-        else:
-            self.inputs = (inputs,)
+        self.batches = batches
+        self.in_channels = in_channels
+        self.length = length
+        self.dtype = dtype
 
         # Build chain of convs
         for i in range(self.nbr_convs):
@@ -100,7 +98,9 @@ def __init__(
             )
 
     def get_inputs(self):
-        return self.inputs
+        return (
+            torch.randn(self.batches, self.in_channels[0], self.length).to(self.dtype),
+        )
 
     def forward(self, x):
         for i in range(self.nbr_convs):

backends/arm/test/ops/test_conv2d.py

Lines changed: 11 additions & 10 deletions

@@ -6,7 +6,7 @@
 
 import unittest
 
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union
 
 import pytest
 
@@ -25,7 +25,6 @@ class Conv2d(torch.nn.Module):
 
     def __init__(
         self,
-        inputs: Optional[torch.Tensor] = None,
         height=8,
         width=8,
         nbr_conv=1,  # Number of chained convs
@@ -76,13 +75,11 @@ def __init__(
         if not isinstance(padding_mode, List):
             padding_mode = [padding_mode]
 
-        # Generate test data if not provided
-        if inputs is None:
-            self.inputs = (
-                torch.randn(batches, in_channels[0], height, width).to(dtype),
-            )
-        else:
-            self.inputs = (inputs,)
+        self.batches = batches
+        self.in_channels = in_channels
+        self.height = height
+        self.width = width
+        self.dtype = dtype
 
         # Build chain of convs
         for i in range(self.nbr_convs):
@@ -103,7 +100,11 @@ def __init__(
             )
 
     def get_inputs(self):
-        return self.inputs
+        return (
+            torch.randn(self.batches, self.in_channels[0], self.height, self.width).to(
+                self.dtype
+            ),
+        )
 
     def forward(self, x):
         for i in range(self.nbr_convs):
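
The two conv diffs above share one idea: __init__ stores shape parameters instead of a cached input tuple, and get_inputs() samples fresh data on every call, so a rerun exercises new values. A minimal hypothetical module following that pattern (the name and layer choice are illustrative, not the actual Conv1d/Conv2d test classes):

import torch


class LazyInputConv1d(torch.nn.Module):
    def __init__(self, batches=1, in_channels=3, length=8, dtype=torch.float32):
        super().__init__()
        self.batches = batches
        self.in_channels = in_channels
        self.length = length
        self.dtype = dtype
        self.conv = torch.nn.Conv1d(in_channels, out_channels=4, kernel_size=3)

    def get_inputs(self):
        # A fresh random input tuple on every call, matching the stored shape/dtype.
        return (torch.randn(self.batches, self.in_channels, self.length).to(self.dtype),)

    def forward(self, x):
        return self.conv(x)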

backends/arm/test/ops/test_conv_combos.py

Lines changed: 1 addition & 2 deletions

@@ -353,8 +353,7 @@ def test_block_bottleneck_residual_tosa_MI(self):
         model = ComboBlockBottleneckResidual()
         self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
 
-    # TODO: Investigate flakyness (MLTORCH-307)
-    @unittest.skip(reason="Skiped due to flakyness (MLTORCH-307)")
+    @pytest.mark.flaky  # TODO: Investigate flakyness (MLTORCH-307)
     def test_block_bottleneck_residual_tosa_BI(self):
         model = ComboBlockBottleneckResidual()
         self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())
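
Swapping @unittest.skip for @pytest.mark.flaky changes the semantics from "never run" to "run, and retry on failure". A self-contained sketch of what pytest-rerunfailures does with the mark; the explicit reruns=3 is an assumption for illustration, while the commit uses the bare mark and relies on the suite's configured default:

import random

import pytest


@pytest.mark.flaky(reruns=3)  # retried up to three times before being reported as failed
def test_sometimes_fails():
    # Fails roughly half the time; with reruns it usually ends up green, while
    # still appearing as rerun in the report so the flakiness stays visible.
    assert random.random() < 0.5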

backends/arm/test/ops/test_depthwise_conv.py

Lines changed: 1 addition & 1 deletion

@@ -252,8 +252,8 @@ def _test_dw_conv_ethos_BI_pipeline(
     def test_dw_conv_tosa_MI(self, test_name: str, model: torch.nn.Module):
         self._test_dw_conv_tosa_MI_pipeline(model, model.get_inputs())
 
-    # TODO: Investigate flakyness (MLTORCH-307)
     @parameterized.expand(testsuite_conv1d + testsuite_conv2d)
+    @pytest.mark.flaky  # TODO: Investigate flakyness (MLTORCH-307)
     def test_dw_conv_tosa_BI(self, test_name: str, model: torch.nn.Module):
         self._test_dw_conv_tosa_BI_pipeline(model, model.get_inputs())
 

backends/arm/test/ops/test_layer_norm.py

Lines changed: 1 addition & 1 deletion

@@ -109,7 +109,7 @@ def _test_layernorm_tosa_BI_pipeline(
             .partition()
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
-            .run_method_and_compare_outputs(inputs=test_data)
+            .run_method_and_compare_outputs(qtol=1, inputs=test_data)
         )
 
     def _test_layernorm_ethosu_BI_pipeline(
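
qtol is the tolerance, in quantization steps, applied when comparing the quantized output against the reference; qtol=1 accepts an element-wise difference of at most one step, matching the value the Ethos-U pipelines already pass (see the qtol=1 call in test_bmm.py above). A rough sketch of what such a comparison amounts to, assuming already-quantized integer tensors; this is not the actual ExecuTorch tester code:

import torch


def within_qtol(ref: torch.Tensor, out: torch.Tensor, qtol: int = 1) -> bool:
    # Element-wise: allow |ref - out| <= qtol quantization steps.
    diff = ref.to(torch.int32) - out.to(torch.int32)
    return bool(diff.abs().le(qtol).all())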
