Skip to content

Commit f28d5fc

Browse files
authored
test: Add xfail_if_cuda marker (#431)
* Add possibility to mark test as xfail_if_cuda * Mark WithRNN as xfail_if_cuda
1 parent 44e91ad commit f28d5fc

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

tests/unit/autogram/test_engine.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,13 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc
158158

159159
@mark.parametrize(
160160
"architecture",
161-
[WithBatchNorm, WithSideEffect, Randomness, WithModuleTrackingRunningStats, WithRNN],
161+
[
162+
WithBatchNorm,
163+
WithSideEffect,
164+
Randomness,
165+
WithModuleTrackingRunningStats,
166+
param(WithRNN, marks=mark.xfail_if_cuda),
167+
],
162168
)
163169
@mark.parametrize("batch_size", [1, 3, 32])
164170
@mark.parametrize("batch_dim", [param(0, marks=mark.xfail), None])

tests/unit/conftest.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,14 @@ def pytest_addoption(parser):
3737

3838
def pytest_configure(config):
3939
config.addinivalue_line("markers", "slow: mark test as slow to run")
40+
config.addinivalue_line("markers", "xfail_if_cuda: mark test as xfail if running on cuda")
4041

4142

4243
def pytest_collection_modifyitems(config, items):
43-
if config.getoption("--runslow"):
44-
return
45-
4644
skip_slow = mark.skip(reason="Slow test. Use --runslow to run it.")
45+
xfail_cuda = mark.xfail(reason=f"Test expected to fail on {DEVICE}")
4746
for item in items:
48-
if "slow" in item.keywords:
47+
if "slow" in item.keywords and not config.getoption("--runslow"):
4948
item.add_marker(skip_slow)
49+
if "xfail_if_cuda" in item.keywords and str(DEVICE).startswith("cuda"):
50+
item.add_marker(xfail_cuda)

0 commit comments

Comments
 (0)