diff --git a/bayesflow/simulators/simulator.py b/bayesflow/simulators/simulator.py index 0a008d655..00d3d84f3 100644 --- a/bayesflow/simulators/simulator.py +++ b/bayesflow/simulators/simulator.py @@ -66,3 +66,32 @@ def rejection_sample( result = tree_concatenate([result, samples], axis=axis, numpy=True) return result + + @allow_batch_size + def sample_batched( + self, + batch_shape: Shape, + *, + sample_size: int, + **kwargs, + ): + """Sample the desired number of simulations in smaller batches. + + Limited resources, especially memory, can make it necessary to run simulations in smaller batches. + The number of samples per simulated batch is specified by `sample_size`. + + Parameters + ---------- + batch_shape : Shape + The desired output shape, as in :py:meth:`sample`. Will be rounded up to the next complete batch. + sample_size : int + The number of samples in each simulated batch. + kwargs + Additional keyword arguments passed to :py:meth:`sample`. + + """ + + def accept_all_predicate(x): + return np.full((sample_size,), True) + + return self.rejection_sample(batch_shape, predicate=accept_all_predicate, sample_size=sample_size, **kwargs) diff --git a/tests/test_simulators/test_simulators.py b/tests/test_simulators/test_simulators.py index f1996c82e..da064bea2 100644 --- a/tests/test_simulators/test_simulators.py +++ b/tests/test_simulators/test_simulators.py @@ -41,6 +41,26 @@ def test_sample(simulator, batch_size): assert not np.allclose(value, value[0]) +def test_sample_batched(simulator, batch_size): + sample_size = 2 + samples = simulator.sample_batched((batch_size,), sample_size=sample_size) + + # test output structure + assert isinstance(samples, dict) + + for key, value in samples.items(): + print(f"{key}.shape = {keras.ops.shape(value)}") + + # test type + assert isinstance(value, np.ndarray) + + # test shape (sample_batched rounds up to complete batches) + assert value.shape[0] == int(np.ceil(batch_size / sample_size)) * sample_size + + # test batch randomness + assert not np.allclose(value, value[0]) + + def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu): samples = composite_gaussian.sample((batch_size,), n=fixed_n, mu=fixed_mu)