Skip to content

Commit 103fe5f

Browse files
committed
actually skip mamba tests if environment is unsuitable
1 parent c8ea3dd commit 103fe5f

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

tests/test_wrappers/test_mamba.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,37 @@
33
import bayesflow as bf
44

55

6-
@pytest.mark.torch
7-
def test_mamba_summary(random_time_series, mamba_summary_network):
6+
def should_skip():
7+
import keras
8+
9+
if keras.backend.backend() != "torch":
10+
return True, "Mamba tests can only be run on PyTorch."
11+
812
import torch
913

1014
if not torch.cuda.is_available():
11-
pytest.skip("This test requires a GPU environment.")
15+
return True, "Mamba tests can only be run on GPU."
16+
17+
try:
18+
import mamba_ssm # noqa: F401
19+
except ImportError:
20+
return True, "Could not import mamba."
21+
22+
return False, None
23+
24+
25+
skip, reason = should_skip()
1226

27+
28+
@pytest.mark.skipif(skip, reason=reason)
29+
def test_mamba_summary(random_time_series, mamba_summary_network):
1330
out = mamba_summary_network(random_time_series)
1431
# Batch size 2, summary dim 4
1532
assert out.shape == (2, 4)
1633

1734

18-
@pytest.mark.torch
35+
@pytest.mark.skipif(skip, reason=reason)
1936
def test_mamba_trains(random_time_series, inference_network, mamba_summary_network):
20-
import torch
21-
22-
if not torch.cuda.is_available():
23-
pytest.skip("This test requires a GPU environment.")
24-
2537
workflow = bf.BasicWorkflow(
2638
inference_network=inference_network,
2739
summary_network=mamba_summary_network,

0 commit comments

Comments
 (0)