Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion bayesflow/simulators/sequential_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class SequentialSimulator(Simulator):
"""Combines multiple simulators into one, sequentially."""

def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True):
def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True, replace_inputs: bool = True):
"""
Initialize a SequentialSimulator.

Expand All @@ -22,10 +22,13 @@ def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True)
expand_outputs : bool, optional
If True, 1D output arrays are expanded with an additional dimension at the end.
Default is True.
replace_inputs : bool, optional
If True, **kwargs are auto-batched and replace simulator outputs.
"""

self.simulators = simulators
self.expand_outputs = expand_outputs
self.replace_inputs = replace_inputs

@allow_batch_size
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
Expand Down Expand Up @@ -53,6 +56,14 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
for simulator in self.simulators:
data |= simulator.sample(batch_shape, **(kwargs | data))

if self.replace_inputs:
common_keys = set(data.keys()) & set(kwargs.keys())
for key in common_keys:
value = kwargs.pop(key)
if isinstance(data[key], np.ndarray):
value = np.broadcast_to(value, data[key].shape)
data[key] = value

if self.expand_outputs:
data = {
key: np.expand_dims(value, axis=-1) if np.ndim(value) == 1 else value for key, value in data.items()
Expand Down
29 changes: 29 additions & 0 deletions tests/test_simulators/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,35 @@ def two_moons_simulator(request):
return request.getfixturevalue(request.param)


@pytest.fixture()
def composite_gaussian():
from bayesflow.simulators import make_simulator

def context():
n = np.random.randint(10, 100)
return dict(n=n)

def prior():
mu = np.random.normal(0, 1)
return dict(mu=mu)

def likelihood(mu, n):
y = np.random.normal(mu, 1, n)
return dict(y=y)

return make_simulator([prior, likelihood], meta_fn=context)


@pytest.fixture()
def fixed_n():
return 5


@pytest.fixture()
def fixed_mu():
return 100


@pytest.fixture(
params=[
"bernoulli_glm",
Expand Down
9 changes: 9 additions & 0 deletions tests/test_simulators/test_simulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,12 @@ def test_sample(simulator, batch_size):

# test batch randomness
assert not np.allclose(value, value[0])


def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
samples = composite_gaussian.sample((batch_size,), n=fixed_n, mu=fixed_mu)

assert samples["n"] == fixed_n
assert samples["mu"].shape == (batch_size, 1)
assert np.all(samples["mu"] == fixed_mu)
assert samples["y"].shape == (batch_size, fixed_n)
Loading