Skip to content

Commit 59cf297

Browse files
committed
Tests for OfflineEnsembleDataset
1 parent 1ca2a76 commit 59cf297

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

tests/test_datasets/conftest.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def num_batches():
1313
return 4
1414

1515

16-
@pytest.fixture(params=["online_dataset", "offline_dataset"])
16+
@pytest.fixture(params=["online_dataset", "offline_dataset", "offline_ensemble_dataset"])
1717
def dataset(request, online_dataset, offline_dataset):
1818
return request.getfixturevalue(request.param)
1919

@@ -46,6 +46,25 @@ def offline_dataset(simulator, batch_size, num_batches, workers, use_multiproces
4646
)
4747

4848

49+
@pytest.fixture()
50+
def offline_ensemble_dataset(simulator, batch_size, num_batches, workers, use_multiprocessing):
51+
from bayesflow import OfflineEnsembleDataset
52+
53+
# TODO: there is a bug in keras where if len(dataset) == 1 batch
54+
# fit will error because no logs are generated
55+
# the single batch is then skipped entirely
56+
num_ensemble = 3
57+
data = simulator.sample((batch_size * num_batches * num_ensemble,))
58+
return OfflineEnsembleDataset(
59+
data=data,
60+
num_ensemble=num_ensemble,
61+
batch_size=batch_size,
62+
workers=workers,
63+
use_multiprocessing=use_multiprocessing,
64+
adapter=None,
65+
)
66+
67+
4968
@pytest.fixture()
5069
def online_dataset(simulator, batch_size, num_batches, workers, use_multiprocessing):
5170
from bayesflow import OnlineDataset

0 commit comments

Comments
 (0)