Skip to content

Commit bae0fdb

Browse files
committed
add tests for overriding simulators
1 parent 01e56ea commit bae0fdb

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

tests/test_simulators/conftest.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,35 @@ def two_moons_simulator(request):
148148
return request.getfixturevalue(request.param)
149149

150150

151+
@pytest.fixture()
152+
def composite_gaussian():
153+
from bayesflow.simulators import make_simulator
154+
155+
def context():
156+
n = np.random.randint(10, 100)
157+
return dict(n=n)
158+
159+
def prior():
160+
mu = np.random.normal(0, 1)
161+
return dict(mu=mu)
162+
163+
def likelihood(mu, n):
164+
y = np.random.normal(mu, 1, n)
165+
return dict(y=y)
166+
167+
return make_simulator([prior, likelihood], meta_fn=context)
168+
169+
170+
@pytest.fixture()
171+
def fixed_n():
172+
return 5
173+
174+
175+
@pytest.fixture()
176+
def fixed_mu():
177+
return 100
178+
179+
151180
@pytest.fixture(
152181
params=[
153182
"bernoulli_glm",

tests/test_simulators/test_simulators.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,12 @@ def test_sample(simulator, batch_size):
3838

3939
# test batch randomness
4040
assert not np.allclose(value, value[0])
41+
42+
43+
def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
44+
samples = composite_gaussian.sample((batch_size,), n=fixed_n, mu=fixed_mu)
45+
46+
assert samples["n"] == fixed_n
47+
assert samples["mu"].shape == (batch_size, 1)
48+
assert np.all(samples["mu"] == fixed_mu)
49+
assert samples["y"].shape == (batch_size, fixed_n)

0 commit comments

Comments
 (0)