Skip to content

Commit b43f1cc

Browse files
authored
Add utility function for batched simulations (#511)
The implementation is a simple wrapper leveraging the batching capabilities of `rejection_sample`.
1 parent 9018ce6 commit b43f1cc

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

bayesflow/simulators/simulator.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,32 @@ def rejection_sample(
6666
result = tree_concatenate([result, samples], axis=axis, numpy=True)
6767

6868
return result
69+
70+
@allow_batch_size
71+
def sample_batched(
72+
self,
73+
batch_shape: Shape,
74+
*,
75+
sample_size: int,
76+
**kwargs,
77+
):
78+
"""Sample the desired number of simulations in smaller batches.
79+
80+
Limited resources, especially memory, can make it necessary to run simulations in smaller batches.
81+
The number of samples per simulated batch is specified by `sample_size`.
82+
83+
Parameters
84+
----------
85+
batch_shape : Shape
86+
The desired output shape, as in :py:meth:`sample`. Will be rounded up to the next complete batch.
87+
sample_size : int
88+
The number of samples in each simulated batch.
89+
kwargs
90+
Additional keyword arguments passed to :py:meth:`sample`.
91+
92+
"""
93+
94+
def accept_all_predicate(x):
95+
return np.full((sample_size,), True)
96+
97+
return self.rejection_sample(batch_shape, predicate=accept_all_predicate, sample_size=sample_size, **kwargs)

tests/test_simulators/test_simulators.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@ def test_sample(simulator, batch_size):
4141
assert not np.allclose(value, value[0])
4242

4343

44+
def test_sample_batched(simulator, batch_size):
45+
sample_size = 2
46+
samples = simulator.sample_batched((batch_size,), sample_size=sample_size)
47+
48+
# test output structure
49+
assert isinstance(samples, dict)
50+
51+
for key, value in samples.items():
52+
print(f"{key}.shape = {keras.ops.shape(value)}")
53+
54+
# test type
55+
assert isinstance(value, np.ndarray)
56+
57+
# test shape (sample_batched rounds up to complete batches)
58+
assert value.shape[0] == int(np.ceil(batch_size / sample_size)) * sample_size
59+
60+
# test batch randomness
61+
assert not np.allclose(value, value[0])
62+
63+
4464
def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
4565
samples = composite_gaussian.sample((batch_size,), n=fixed_n, mu=fixed_mu)
4666

0 commit comments

Comments
 (0)