Skip to content

Commit 7d8da19

Browse files
martinlsmMartin Lindström
andauthored
Arm backend: Mark test in test_bmm.py as flaky (#14748)
Add a new argument to `common.parametrize`, `flakies`, which selects which parametrized test cases to mark as flaky. With this new argument, mark the test test_bmm.py::test_bmm_vgf_FP_single_input[rand_big_1] as flaky. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 Signed-off-by: Martin Lindström <[email protected]> Co-authored-by: Martin Lindström <[email protected]>
1 parent 8ac6300 commit 7d8da19

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

backends/arm/test/common.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def parametrize(
227227
test_data: dict[str, Any],
228228
xfails: dict[str, xfail_type] | None = None,
229229
strict: bool = True,
230+
flakies: dict[str, int] | None = None,
230231
):
231232
"""
232233
Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality
@@ -237,12 +238,17 @@ def parametrize(
237238
"""
238239
if xfails is None:
239240
xfails = {}
241+
if flakies is None:
242+
flakies = {}
240243

241244
def decorator_func(func):
242245
"""Test data is transformed from a dict of (id, data) pairs to a list of pytest params to work with the native pytests parametrize function"""
243246
pytest_testsuite = []
244247
for id, test_parameters in test_data.items():
245-
if id in xfails:
248+
if id in flakies:
249+
# Mark this parameter as flaky with given reruns
250+
marker = (pytest.mark.flaky(reruns=flakies[id]),)
251+
elif id in xfails:
246252
xfail_info = xfails[id]
247253
reason = ""
248254
raises = None

backends/arm/test/ops/test_bmm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ def test_bmm_vgf_FP(test_data: input_t1):
146146
pipeline.run()
147147

148148

149-
@common.parametrize("test_data", BMMSingleInput.test_data_generators)
149+
@common.parametrize(
150+
"test_data",
151+
BMMSingleInput.test_data_generators,
152+
flakies={"rand_big_1": 3},
153+
)
150154
@common.SkipIfNoModelConverter
151155
def test_bmm_vgf_FP_single_input(test_data: input_t1):
152156
pipeline = VgfPipeline[input_t1](

0 commit comments

Comments
 (0)