Skip to content

Commit c8ea3dd

Browse files
committed
revert trying to test mamba on workflow, since workflows do not have a GPU anyway
1 parent 6147534 commit c8ea3dd

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

.github/workflows/tests.yaml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,6 @@ jobs:
6161
run: |
6262
pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
6363
64-
# we currently have extra dependencies for torch tests
65-
# TODO: can we install this dynamically? should we add it to the test dependencies?
66-
- name: Install Extra Dependencies
67-
if: ${{ matrix.backend == 'torch' }}
68-
run: |
69-
pip install -U mamba-ssm
70-
7164
- name: Show Environment Info
7265
run: |
7366
python --version

tests/test_wrappers/test_mamba.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,23 @@
55

66
@pytest.mark.torch
77
def test_mamba_summary(random_time_series, mamba_summary_network):
8+
import torch
9+
10+
if not torch.cuda.is_available():
11+
pytest.skip("This test requires a GPU environment.")
12+
813
out = mamba_summary_network(random_time_series)
914
# Batch size 2, summary dim 4
1015
assert out.shape == (2, 4)
1116

1217

1318
@pytest.mark.torch
1419
def test_mamba_trains(random_time_series, inference_network, mamba_summary_network):
20+
import torch
21+
22+
if not torch.cuda.is_available():
23+
pytest.skip("This test requires a GPU environment.")
24+
1525
workflow = bf.BasicWorkflow(
1626
inference_network=inference_network,
1727
summary_network=mamba_summary_network,

0 commit comments

Comments
 (0)