Skip to content

Commit 45c43f2

Browse files
committed
Improve docs for simulators
1 parent 2bedf47 commit 45c43f2

File tree

5 files changed

+141
-1
lines changed

5 files changed

+141
-1
lines changed

bayesflow/simulators/hierarchical_simulator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,38 @@
1010

1111
class HierarchicalSimulator(Simulator):
1212
def __init__(self, hierarchy: Sequence[Simulator]):
13+
"""
14+
Initialize the hierarchical simulator with a sequence of simulators.
15+
16+
Parameters
17+
----------
18+
hierarchy : Sequence[Simulator]
19+
A sequence of simulator instances representing each level of the hierarchy.
20+
Each level's output is used as input for the next, with increasing batch dimensions.
21+
"""
1322
self.hierarchy = hierarchy
1423

1524
@allow_batch_size
1625
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
26+
"""
27+
Sample from a hierarchy of simulators.
28+
29+
Parameters
30+
----------
31+
batch_shape : Shape
32+
A tuple where each element specifies the number of samples at the corresponding level
33+
of the hierarchy. The total batch size increases multiplicatively through the levels.
34+
**kwargs
35+
Additional keyword arguments passed to each simulator. These are combined with outputs
36+
from previous levels and repeated appropriately.
37+
38+
Returns
39+
-------
40+
output_data : dict of str to np.ndarray
41+
A dictionary containing the outputs from the entire hierarchy. Outputs are reshaped to
42+
match the hierarchical batch shape, i.e., with shape equal to `batch_shape + original_shape`.
43+
"""
44+
1745
input_data = {}
1846
output_data = {}
1947

bayesflow/simulators/lambda_simulator.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections.abc import Callable, Sequence, Mapping
2+
13
import numpy as np
24

35
from bayesflow.utils import batched_call, filter_kwargs, tree_stack
@@ -10,12 +12,44 @@
1012
class LambdaSimulator(Simulator):
1113
"""Implements a simulator based on a sampling function."""
1214

13-
def __init__(self, sample_fn: callable, *, is_batched: bool = False):
15+
def __init__(self, sample_fn: Callable[Sequence[int, Mapping[str, any]]], *, is_batched: bool = False):
16+
"""
17+
Initialize a simulator based on a simple callable function.
18+
19+
Parameters
20+
----------
21+
sample_fn : Callable[[Sequence[int]], Mapping[str, Any]]
22+
A function that generates samples. It should accept `batch_shape` as its first argument
23+
(if `is_batched=True`), followed by keyword arguments.
24+
is_batched : bool, optional
25+
Whether the `sample_fn` is implemented to handle batched sampling directly.
26+
If False, `sample_fn` will be called once per sample and results will be stacked.
27+
Default is False.
28+
"""
1429
self.sample_fn = sample_fn
1530
self.is_batched = is_batched
1631

1732
@allow_batch_size
1833
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
34+
"""
35+
Sample using the wrapped sampling function.
36+
37+
Parameters
38+
----------
39+
batch_shape : Shape
40+
The shape of the batch to sample. Typically, a tuple indicating the number of samples,
41+
but an int can also be passed.
42+
**kwargs
43+
Additional keyword arguments passed to the sampling function. Only valid arguments
44+
(as determined by the function's signature) are used.
45+
46+
Returns
47+
-------
48+
data : dict of str to np.ndarray
49+
A dictionary of sampled outputs. Keys are output names and values are numpy arrays.
50+
If `is_batched` is False, individual outputs are stacked along the first axis.
51+
"""
52+
1953
# try to use only valid keyword-arguments
2054
kwargs = filter_kwargs(kwargs, self.sample_fn)
2155

bayesflow/simulators/make_simulator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010

1111
@singledispatch
1212
def make_simulator(arg, *_, **__):
13+
"""
14+
This is a dispatch function that accepts a list of simulators (callables), each returning a
15+
dictionary of simulated outputs. The outputs of earlier simulators are passed to subsequent
16+
simulators whenever the latter accept keyword arguments matching the keys of those outputs.
17+
"""
1318
raise TypeError(f"Cannot infer simulator from {arg!r}.")
1419

1520

bayesflow/simulators/model_comparison_simulator.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,26 @@ def __init__(
2424
use_mixed_batches: bool = True,
2525
shared_simulator: Simulator | FunctionType = None,
2626
):
27+
"""
28+
Initialize a multi-model simulator that can generate data for mixture / model comparison problems.
29+
30+
Parameters
31+
----------
32+
simulators : Sequence[Simulator]
33+
A sequence of simulator instances, each representing a different model.
34+
p : Sequence[float], optional
35+
A sequence of probabilities associated with each simulator. Must sum to 1.
36+
Mutually exclusive with `logits`.
37+
logits : Sequence[float], optional
38+
A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
39+
If neither `p` nor `logits` is provided, defaults to uniform logits.
40+
use_mixed_batches : bool, optional
41+
If True, samples in a batch are drawn from different models. If False, the entire batch
42+
is drawn from a single model chosen according to the model probabilities. Default is True.
43+
shared_simulator : Simulator or FunctionType, optional
44+
A shared simulator whose outputs are passed to all model simulators. If a function is
45+
provided, it is wrapped in a `LambdaSimulator` with batching enabled.
46+
"""
2747
self.simulators = simulators
2848

2949
if isinstance(shared_simulator, FunctionType):
@@ -51,6 +71,26 @@ def __init__(
5171

5272
@allow_batch_size
5373
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
74+
"""
75+
Sample from the model comparison simulator.
76+
77+
Parameters
78+
----------
79+
batch_shape : Shape
80+
The shape of the batch to sample. Typically, a tuple indicating the number of samples,
81+
but can also be an int.
82+
**kwargs
83+
Additional keyword arguments passed to each simulator. These may include outputs from
84+
the shared simulator.
85+
86+
Returns
87+
-------
88+
data : dict of str to np.ndarray
89+
A dictionary containing the sampled outputs. Includes:
90+
- outputs from the selected simulator(s)
91+
- optionally, outputs from the shared simulator
92+
- "model_indices": a one-hot encoded array indicating the model origin of each sample
93+
"""
5494
data = {}
5595
if self.shared_simulator:
5696
data |= self.shared_simulator.sample(batch_shape, **kwargs)

bayesflow/simulators/sequential_simulator.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,44 @@ class SequentialSimulator(Simulator):
1111
"""Combines multiple simulators into one, sequentially."""
1212

1313
def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True):
14+
"""
15+
Initialize a SequentialSimulator.
16+
17+
Parameters
18+
----------
19+
simulators : Sequence[Simulator]
20+
A sequence of simulator instances to be executed sequentially. Each simulator should
21+
return dictionary outputs and may depend on outputs from previous simulators.
22+
expand_outputs : bool, optional
23+
If True, 1D output arrays are expanded with an additional dimension at the end.
24+
Default is True.
25+
"""
26+
1427
self.simulators = simulators
1528
self.expand_outputs = expand_outputs
1629

1730
@allow_batch_size
1831
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
32+
"""
33+
Sample sequentially from the internal simulators.
34+
35+
Parameters
36+
----------
37+
batch_shape : Shape
38+
The shape of the batch to sample. Typically, a tuple indicating the number of samples,
39+
but it also accepts an int.
40+
**kwargs
41+
Additional keyword arguments passed to each simulator. These may include previously
42+
sampled outputs used as inputs for subsequent simulators.
43+
44+
Returns
45+
-------
46+
data : dict of str to np.ndarray
47+
A dictionary containing the combined outputs from all simulators. Keys are output names
48+
and values are sampled arrays. If `expand_outputs` is True, 1D arrays are expanded to
49+
have shape (..., 1).
50+
"""
51+
1952
data = {}
2053
for simulator in self.simulators:
2154
data |= simulator.sample(batch_shape, **(kwargs | data))

0 commit comments

Comments
 (0)