Skip to content

Commit a77025e

Browse files
committed
Fix import
1 parent 43f561a commit a77025e

File tree

4 files changed

+48
-1
lines changed

4 files changed

+48
-1
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from mamba import Mamba
1+
from .mamba import Mamba

tests/test_wrappers/__init__.py

Whitespace-only changes.

tests/test_wrappers/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
3+
4+
@pytest.fixture()
5+
def inference_network():
6+
from bayesflow.networks import CouplingFlow
7+
8+
return CouplingFlow(depth=2)
9+
10+
11+
@pytest.fixture()
12+
def random_time_series():
13+
import keras
14+
15+
return keras.random.normal(shape=(2, 80, 2))
16+
17+
18+
@pytest.fixture()
19+
def mamba_summary_network():
20+
from bayesflow.wrappers.mamba import Mamba
21+
22+
return Mamba(summary_dim=4, feature_dims=(2, 2), state_dims=(4, 4), conv_dims=(8, 8))

tests/test_wrappers/test_mamba.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
3+
import bayesflow as bf
4+
5+
6+
@pytest.mark.torch
7+
def test_mamba_summary(random_time_series, mamba_summary_network):
8+
out = mamba_summary_network(random_time_series)
9+
# Batch size 2, summary dim 4
10+
assert out.shape == (2, 4)
11+
12+
13+
@pytest.mark.torch
14+
def test_mamba_trains(random_time_series, inference_network, mamba_summary_network):
15+
workflow = bf.BasicWorkflow(
16+
inference_network=inference_network,
17+
summary_network=mamba_summary_network,
18+
inference_variables=["parameters"],
19+
summary_variables=["observables"],
20+
simulator=bf.simulators.SIR(),
21+
)
22+
23+
history = workflow.fit_online(epochs=2, batch_size=8, num_batches_per_epoch=2)
24+
assert "loss" in list(history.history.keys())
25+
assert len(history.history["loss"]) == 2

0 commit comments

Comments
 (0)