Skip to content

Commit 3c93679

Browse files
committed
add test
1 parent a717600 commit 3c93679

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

tests/test_simulators/conftest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,32 @@ def likelihood(mu, n):
167167
return make_simulator([prior, likelihood], meta_fn=context)
168168

169169

170+
@pytest.fixture(params=["drop", np.nan])
171+
def multimodel(request):
172+
from bayesflow.simulators import make_simulator, ModelComparisonSimulator
173+
174+
rng = np.random.default_rng()
175+
176+
def prior_1():
177+
return dict(w=rng.uniform())
178+
179+
def prior_2():
180+
return dict(c=rng.uniform())
181+
182+
def model_1(w):
183+
return dict(x=w)
184+
185+
def model_2(c):
186+
return dict(x=c)
187+
188+
simulator_1 = make_simulator([prior_1, model_1])
189+
simulator_2 = make_simulator([prior_2, model_2])
190+
191+
simulator = ModelComparisonSimulator(simulators=[simulator_1, simulator_2], key_conflicts=request.param)
192+
193+
return simulator
194+
195+
170196
@pytest.fixture()
171197
def fixed_n():
172198
return 5

tests/test_simulators/test_simulators.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,13 @@ def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
4747
assert samples["mu"].shape == (batch_size, 1)
4848
assert np.all(samples["mu"] == fixed_mu)
4949
assert samples["y"].shape == (batch_size, fixed_n)
50+
51+
52+
def test_multimodel_sample(multimodel, batch_size):
53+
samples = multimodel.sample(batch_size)
54+
55+
if multimodel.key_conflicts == "drop":
56+
assert set(samples) == {"x", "model_indices"}
57+
else:
58+
assert set(samples) == {"x", "model_indices", "c", "w"}
59+
assert np.sum(np.isnan(samples["c"])) + np.sum(np.isnan(samples["w"])) == batch_size

0 commit comments

Comments
 (0)