Skip to content

Commit 8482926

Browse files
authored
overriding simulator outputs with auto-batched inputs (#420)
* overriding simulator outputs with auto-batched inputs * add tests for overriding simulators
1 parent dc5ee17 commit 8482926

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

bayesflow/simulators/sequential_simulator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
class SequentialSimulator(Simulator):
1111
"""Combines multiple simulators into one, sequentially."""
1212

13-
def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True):
13+
def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True, replace_inputs: bool = True):
1414
"""
1515
Initialize a SequentialSimulator.
1616
@@ -22,10 +22,13 @@ def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True)
2222
expand_outputs : bool, optional
2323
If True, 1D output arrays are expanded with an additional dimension at the end.
2424
Default is True.
25+
replace_inputs : bool, optional
26+
If True, **kwargs are auto-batched and replace simulator outputs.
2527
"""
2628

2729
self.simulators = simulators
2830
self.expand_outputs = expand_outputs
31+
self.replace_inputs = replace_inputs
2932

3033
@allow_batch_size
3134
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
@@ -53,6 +56,14 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
5356
for simulator in self.simulators:
5457
data |= simulator.sample(batch_shape, **(kwargs | data))
5558

59+
if self.replace_inputs:
60+
common_keys = set(data.keys()) & set(kwargs.keys())
61+
for key in common_keys:
62+
value = kwargs.pop(key)
63+
if isinstance(data[key], np.ndarray):
64+
value = np.broadcast_to(value, data[key].shape)
65+
data[key] = value
66+
5667
if self.expand_outputs:
5768
data = {
5869
key: np.expand_dims(value, axis=-1) if np.ndim(value) == 1 else value for key, value in data.items()

tests/test_simulators/conftest.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,35 @@ def two_moons_simulator(request):
148148
return request.getfixturevalue(request.param)
149149

150150

151+
@pytest.fixture()
152+
def composite_gaussian():
153+
from bayesflow.simulators import make_simulator
154+
155+
def context():
156+
n = np.random.randint(10, 100)
157+
return dict(n=n)
158+
159+
def prior():
160+
mu = np.random.normal(0, 1)
161+
return dict(mu=mu)
162+
163+
def likelihood(mu, n):
164+
y = np.random.normal(mu, 1, n)
165+
return dict(y=y)
166+
167+
return make_simulator([prior, likelihood], meta_fn=context)
168+
169+
170+
@pytest.fixture()
171+
def fixed_n():
172+
return 5
173+
174+
175+
@pytest.fixture()
176+
def fixed_mu():
177+
return 100
178+
179+
151180
@pytest.fixture(
152181
params=[
153182
"bernoulli_glm",

tests/test_simulators/test_simulators.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,12 @@ def test_sample(simulator, batch_size):
3838

3939
# test batch randomness
4040
assert not np.allclose(value, value[0])
41+
42+
43+
def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
44+
samples = composite_gaussian.sample((batch_size,), n=fixed_n, mu=fixed_mu)
45+
46+
assert samples["n"] == fixed_n
47+
assert samples["mu"].shape == (batch_size, 1)
48+
assert np.all(samples["mu"] == fixed_mu)
49+
assert samples["y"].shape == (batch_size, fixed_n)

0 commit comments

Comments
 (0)