Skip to content

Commit 3bb7ee5

Browse files
committed
fix two moons test
1 parent 382afef commit 3bb7ee5

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

tests/test_networks/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ def generative_inference_network(request):
119119

120120

121121
@pytest.fixture(scope="function")
122-
def lst_net(summary_dim):
123-
from bayesflow.networks import LSTNet
122+
def time_series_network(summary_dim):
123+
from bayesflow.networks import TimeSeriesNetwork
124124

125-
return LSTNet(summary_dim=summary_dim)
125+
return TimeSeriesNetwork(summary_dim=summary_dim)
126126

127127

128128
@pytest.fixture(scope="function")
@@ -139,7 +139,7 @@ def deep_set(summary_dim):
139139
return DeepSet(summary_dim=summary_dim)
140140

141141

142-
@pytest.fixture(params=[None, "lst_net", "set_transformer", "deep_set"], scope="function")
142+
@pytest.fixture(params=[None, "time_series_network", "set_transformer", "deep_set"], scope="function")
143143
def summary_network(request, summary_dim):
144144
if request.param is None:
145145
return None

tests/test_simulators/test_simulators.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@ def test_two_moons(simulator, batch_size):
66
samples = simulator.sample((batch_size,))
77

88
assert isinstance(samples, dict)
9-
assert list(samples.keys()) == ["r", "alpha", "theta", "x"]
9+
assert list(samples.keys()) == ["parameters", "observables"]
1010
assert all(isinstance(value, np.ndarray) for value in samples.values())
1111

12-
assert samples["r"].shape == (batch_size, 1)
13-
assert samples["alpha"].shape == (batch_size, 1)
14-
assert samples["theta"].shape == (batch_size, 2)
15-
assert samples["x"].shape == (batch_size, 2)
12+
assert samples["parameters"].shape == (batch_size, 2)
13+
assert samples["observables"].shape == (batch_size, 2)
1614

1715

1816
def test_sample(simulator, batch_size):

0 commit comments

Comments
 (0)