|
3 | 3 | import bayesflow as bf |
4 | 4 |
|
5 | 5 |
|
6 | | -@pytest.mark.torch |
7 | | -def test_mamba_summary(random_time_series, mamba_summary_network): |
| 6 | +def should_skip(): |
| 7 | + import keras |
| 8 | + |
| 9 | + if keras.backend.backend() != "torch": |
| 10 | + return True, "Mamba tests can only be run on PyTorch." |
| 11 | + |
8 | 12 | import torch |
9 | 13 |
|
10 | 14 | if not torch.cuda.is_available(): |
11 | | - pytest.skip("This test requires a GPU environment.") |
| 15 | + return True, "Mamba tests can only be run on GPU." |
| 16 | + |
| 17 | + try: |
| 18 | + import mamba_ssm # noqa: F401 |
| 19 | + except ImportError: |
| 20 | + return True, "Could not import mamba." |
| 21 | + |
| 22 | + return False, None |
| 23 | + |
| 24 | + |
| 25 | +skip, reason = should_skip() |
12 | 26 |
|
| 27 | + |
| 28 | +@pytest.mark.skipif(skip, reason=reason) |
| 29 | +def test_mamba_summary(random_time_series, mamba_summary_network): |
13 | 30 | out = mamba_summary_network(random_time_series) |
14 | 31 | # Batch size 2, summary dim 4 |
15 | 32 | assert out.shape == (2, 4) |
16 | 33 |
|
17 | 34 |
|
18 | | -@pytest.mark.torch |
| 35 | +@pytest.mark.skipif(skip, reason=reason) |
19 | 36 | def test_mamba_trains(random_time_series, inference_network, mamba_summary_network): |
20 | | - import torch |
21 | | - |
22 | | - if not torch.cuda.is_available(): |
23 | | - pytest.skip("This test requires a GPU environment.") |
24 | | - |
25 | 37 | workflow = bf.BasicWorkflow( |
26 | 38 | inference_network=inference_network, |
27 | 39 | summary_network=mamba_summary_network, |
|
0 commit comments