diff --git a/bayesflow/simulators/sequential_simulator.py b/bayesflow/simulators/sequential_simulator.py
index a3939f650..d1b71f43a 100644
--- a/bayesflow/simulators/sequential_simulator.py
+++ b/bayesflow/simulators/sequential_simulator.py
@@ -10,7 +10,7 @@ class SequentialSimulator(Simulator):
     """Combines multiple simulators into one, sequentially."""
 
-    def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True):
+    def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True, replace_inputs: bool = True):
         """
         Initialize a SequentialSimulator.
 
@@ -22,10 +22,13 @@ def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True)
         expand_outputs : bool, optional
             If True, 1D output arrays are expanded with an additional dimension at the end.
             Default is True.
+        replace_inputs : bool, optional
+            If True, **kwargs are auto-batched and replace matching simulator outputs. Default is True.
         """
 
         self.simulators = simulators
         self.expand_outputs = expand_outputs
+        self.replace_inputs = replace_inputs
 
     @allow_batch_size
     def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
@@ -53,6 +56,14 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
         for simulator in self.simulators:
             data |= simulator.sample(batch_shape, **(kwargs | data))
 
+        if self.replace_inputs:
+            common_keys = set(data.keys()) & set(kwargs.keys())
+            for key in common_keys:
+                value = kwargs.pop(key)
+                if isinstance(data[key], np.ndarray):
+                    value = np.broadcast_to(value, data[key].shape)
+                data[key] = value
+
         if self.expand_outputs:
             data = {
                 key: np.expand_dims(value, axis=-1) if np.ndim(value) == 1 else value for key, value in data.items()
diff --git a/tests/test_simulators/conftest.py b/tests/test_simulators/conftest.py
index 29d7eaf15..0e76a5396 100644
--- a/tests/test_simulators/conftest.py
+++ b/tests/test_simulators/conftest.py
@@ -148,6 +148,35 @@ def two_moons_simulator(request):
     return request.getfixturevalue(request.param)
 
 
+@pytest.fixture()
+def composite_gaussian():
+    from bayesflow.simulators import make_simulator
+
+    def context():
+        n = np.random.randint(10, 100)
+        return dict(n=n)
+
+    def prior():
+        mu = np.random.normal(0, 1)
+        return dict(mu=mu)
+
+    def likelihood(mu, n):
+        y = np.random.normal(mu, 1, n)
+        return dict(y=y)
+
+    return make_simulator([prior, likelihood], meta_fn=context)
+
+
+@pytest.fixture()
+def fixed_n():
+    return 5
+
+
+@pytest.fixture()
+def fixed_mu():
+    return 100
+
+
 @pytest.fixture(
     params=[
         "bernoulli_glm",
diff --git a/tests/test_simulators/test_simulators.py b/tests/test_simulators/test_simulators.py
index f95d3fef8..e9a3c80c0 100644
--- a/tests/test_simulators/test_simulators.py
+++ b/tests/test_simulators/test_simulators.py
@@ -38,3 +38,12 @@ def test_sample(simulator, batch_size):
 
     # test batch randomness
     assert not np.allclose(value, value[0])
+
+
+def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
+    samples = composite_gaussian.sample((batch_size,), n=fixed_n, mu=fixed_mu)
+
+    assert samples["n"] == fixed_n
+    assert samples["mu"].shape == (batch_size, 1)
+    assert np.all(samples["mu"] == fixed_mu)
+    assert samples["y"].shape == (batch_size, fixed_n)
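
Usage note (not part of the diff): the snippet below is a minimal sketch of the new `replace_inputs` behavior, mirroring the `composite_gaussian` fixture and `test_fixed_sample` added above. With `replace_inputs=True` (the default), keyword arguments passed to `sample` replace the corresponding simulator outputs and are broadcast to the batched output shapes. The batch size of 32 and the simulator functions here are illustrative assumptions taken from the test fixture, not additional library API.

```python
import numpy as np

from bayesflow.simulators import make_simulator


def context():
    # meta function: draws a random number of observations per dataset
    return dict(n=np.random.randint(10, 100))


def prior():
    # prior over the mean
    return dict(mu=np.random.normal(0, 1))


def likelihood(mu, n):
    # n observations centered on mu
    return dict(y=np.random.normal(mu, 1, n))


simulator = make_simulator([prior, likelihood], meta_fn=context)

# Fix both the context variable and the prior draw; with replace_inputs=True,
# the passed values replace the sampled outputs and are broadcast to the
# batched output shapes.
samples = simulator.sample((32,), n=5, mu=100)

assert samples["n"] == 5                # context replaced by the fixed value
assert samples["mu"].shape == (32, 1)   # expanded and broadcast over the batch
assert np.all(samples["mu"] == 100)
assert samples["y"].shape == (32, 5)    # each row drawn around the fixed mu
```

These assertions match `test_fixed_sample`: `n` comes from the `meta_fn` and therefore stays unbatched, while `mu` is broadcast over the batch and expanded to `(batch_size, 1)` by `expand_outputs`.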