
Commit be274c1

Merge remote-tracking branch 'upstream/dev' into feat-integration-schedule
2 parents: 1c6b90d + 18df01c

11 files changed: +795 -336 lines

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@ def calibration_ecdf(
     figsize: Sequence[float] = None,
     label_fontsize: int = 16,
     legend_fontsize: int = 14,
+    legend_location: str = "upper right",
     title_fontsize: int = 18,
     tick_fontsize: int = 12,
     rank_ecdf_color: str = "#132a70",
@@ -184,7 +185,7 @@ def calibration_ecdf(

     for ax, title in zip(plot_data["axes"].flat, titles):
         ax.fill_between(z, L, U, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
-        ax.legend(fontsize=legend_fontsize)
+        ax.legend(fontsize=legend_fontsize, loc=legend_location)
         ax.set_title(title, fontsize=title_fontsize)

     prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
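
The same `legend_location` keyword is added to `calibration_ecdf_from_quantiles` in the next diff, so a single usage sketch covers both. Below is a minimal, hedged example of how a caller might use the new option; the argument names `estimates` and `targets`, their shapes, and the exact import path are assumptions for illustration and are not part of this commit.

import numpy as np
from bayesflow.diagnostics.plots import calibration_ecdf  # module path taken from the file name above

rng = np.random.default_rng(2024)

# Toy inputs: 100 simulated datasets, 500 posterior draws each, 2 parameters.
# Argument names and shapes are illustrative assumptions, not taken from this diff.
estimates = rng.normal(size=(100, 500, 2))
targets = rng.normal(size=(100, 2))

# The new keyword pins the legend explicitly, e.g. to keep it clear of the confidence bands.
fig = calibration_ecdf(estimates, targets, legend_location="lower right")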

bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@ def calibration_ecdf_from_quantiles(
     figsize: Sequence[float] = None,
     label_fontsize: int = 16,
     legend_fontsize: int = 14,
+    legend_location: str = "upper right",
     title_fontsize: int = 18,
     tick_fontsize: int = 12,
     rank_ecdf_color: str = "#132a70",
@@ -173,7 +174,7 @@ def calibration_ecdf_from_quantiles(
             alpha=0.2,
             label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands" + "\n(pointwise)",
         )
-        ax.legend(fontsize=legend_fontsize)
+        ax.legend(fontsize=legend_fontsize, loc=legend_location)
         ax.set_title(title, fontsize=title_fontsize)

     prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

bayesflow/simulators/hierarchical_simulator.py

Lines changed: 28 additions & 0 deletions
@@ -10,10 +10,38 @@

 class HierarchicalSimulator(Simulator):
     def __init__(self, hierarchy: Sequence[Simulator]):
+        """
+        Initialize the hierarchical simulator with a sequence of simulators.
+
+        Parameters
+        ----------
+        hierarchy : Sequence[Simulator]
+            A sequence of simulator instances representing each level of the hierarchy.
+            Each level's output is used as input for the next, with increasing batch dimensions.
+        """
         self.hierarchy = hierarchy

     @allow_batch_size
     def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
+        """
+        Sample from a hierarchy of simulators.
+
+        Parameters
+        ----------
+        batch_shape : Shape
+            A tuple where each element specifies the number of samples at the corresponding level
+            of the hierarchy. The total batch size increases multiplicatively through the levels.
+        **kwargs
+            Additional keyword arguments passed to each simulator. These are combined with outputs
+            from previous levels and repeated appropriately.
+
+        Returns
+        -------
+        output_data : dict of str to np.ndarray
+            A dictionary containing the outputs from the entire hierarchy. Outputs are reshaped to
+            match the hierarchical batch shape, i.e., with shape equal to `batch_shape + original_shape`.
+        """
+
         input_data = {}
         output_data = {}
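
A minimal sketch of how the newly documented `HierarchicalSimulator` API might be used, assuming both `HierarchicalSimulator` and `LambdaSimulator` are importable from `bayesflow.simulators`; the two toy sampling functions (`group_level`, `subject_level`) are illustrative and not part of the diff.

import numpy as np
from bayesflow.simulators import HierarchicalSimulator, LambdaSimulator  # import path assumed

# Level 1: group-level means; Level 2: per-subject observations conditioned on those means.
def group_level():
    return {"group_mean": np.random.normal(loc=0.0, scale=1.0)}

def subject_level(group_mean):
    return {"y": np.random.normal(loc=group_mean, scale=0.5)}

simulator = HierarchicalSimulator(
    hierarchy=[LambdaSimulator(group_level), LambdaSimulator(subject_level)]
)

# 4 groups with 10 subjects each; outputs follow batch_shape + the per-draw output shape,
# per the docstring above (exact shapes depend on what the sampling functions return).
samples = simulator.sample((4, 10))
print(samples["group_mean"].shape, samples["y"].shape)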

bayesflow/simulators/lambda_simulator.py

Lines changed: 35 additions & 1 deletion
@@ -1,3 +1,5 @@
+from collections.abc import Callable, Sequence, Mapping
+
 import numpy as np

 from bayesflow.utils import batched_call, filter_kwargs, tree_stack
@@ -10,12 +12,44 @@
 class LambdaSimulator(Simulator):
     """Implements a simulator based on a sampling function."""

-    def __init__(self, sample_fn: callable, *, is_batched: bool = False):
+    def __init__(self, sample_fn: Callable[[Sequence[int]], Mapping[str, any]], *, is_batched: bool = False):
+        """
+        Initialize a simulator based on a simple callable function.
+
+        Parameters
+        ----------
+        sample_fn : Callable[[Sequence[int]], Mapping[str, any]]
+            A function that generates samples. It should accept `batch_shape` as its first argument
+            (if `is_batched=True`), followed by keyword arguments.
+        is_batched : bool, optional
+            Whether the `sample_fn` is implemented to handle batched sampling directly.
+            If False, `sample_fn` will be called once per sample and results will be stacked.
+            Default is False.
+        """
         self.sample_fn = sample_fn
         self.is_batched = is_batched

     @allow_batch_size
     def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
+        """
+        Sample using the wrapped sampling function.
+
+        Parameters
+        ----------
+        batch_shape : Shape
+            The shape of the batch to sample. Typically, a tuple indicating the number of samples,
+            but an int can also be passed.
+        **kwargs
+            Additional keyword arguments passed to the sampling function. Only valid arguments
+            (as determined by the function's signature) are used.
+
+        Returns
+        -------
+        data : dict of str to np.ndarray
+            A dictionary of sampled outputs. Keys are output names and values are numpy arrays.
+            If `is_batched` is False, individual outputs are stacked along the first axis.
+        """
+
         # try to use only valid keyword-arguments
         kwargs = filter_kwargs(kwargs, self.sample_fn)

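
A short sketch of both calling conventions described in the new `LambdaSimulator` docstrings, assuming the class is exported from `bayesflow.simulators`; the toy `prior` and `batched_prior` functions are illustrative only.

import numpy as np
from bayesflow.simulators import LambdaSimulator  # import path assumed

# Non-batched usage: the function draws one parameter set per call and
# LambdaSimulator stacks the results along the first axis.
def prior():
    return {"theta": np.random.normal(size=2)}

simulator = LambdaSimulator(prior)        # is_batched defaults to False
samples = simulator.sample(100)           # an int is also accepted per the docstring
print(samples["theta"].shape)             # expected (100, 2)

# Batched usage: the function receives batch_shape directly and vectorizes itself.
def batched_prior(batch_shape):
    return {"theta": np.random.normal(size=(*batch_shape, 2))}

batched = LambdaSimulator(batched_prior, is_batched=True)
print(batched.sample((100,))["theta"].shape)  # expected (100, 2)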

bayesflow/simulators/make_simulator.py

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,11 @@

 @singledispatch
 def make_simulator(arg, *_, **__):
+    """
+    This is a dispatch function that will accept a list of simulators (callables) returning
+    dictionaries with simulated outputs. The outputs of simulators will be passed to following
+    simulators if the latter accept keyword arguments associated with the keys of previous outputs.
+    """
     raise TypeError(f"Cannot infer simulator from {arg!r}.")


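
A hedged sketch of the behavior the new docstring describes: later callables receive earlier outputs through matching keyword names. The `prior`/`likelihood` functions and the import path are illustrative assumptions, not taken from the diff.

import numpy as np
from bayesflow.simulators import make_simulator  # import path assumed

def prior():
    return {"theta": np.random.normal(size=2)}

def likelihood(theta):
    # receives "theta" because the previous callable produced an output with that key
    return {"y": np.random.normal(loc=theta, scale=1.0, size=(10, 2))}

simulator = make_simulator([prior, likelihood])
sample = simulator.sample(32)
print(sorted(sample.keys()))  # expected to contain "theta" and "y"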

bayesflow/simulators/model_comparison_simulator.py

Lines changed: 56 additions & 20 deletions
@@ -24,6 +24,26 @@ def __init__(
         use_mixed_batches: bool = True,
         shared_simulator: Simulator | FunctionType = None,
     ):
+        """
+        Initialize a multi-model simulator that can generate data for mixture / model comparison problems.
+
+        Parameters
+        ----------
+        simulators : Sequence[Simulator]
+            A sequence of simulator instances, each representing a different model.
+        p : Sequence[float], optional
+            A sequence of probabilities associated with each simulator. Must sum to 1.
+            Mutually exclusive with `logits`.
+        logits : Sequence[float], optional
+            A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
+            If neither `p` nor `logits` is provided, defaults to uniform logits.
+        use_mixed_batches : bool, optional
+            If True, samples in a batch are drawn from different models. If False, the entire batch
+            is drawn from a single model chosen according to the model probabilities. Default is True.
+        shared_simulator : Simulator or FunctionType, optional
+            A shared simulator whose outputs are passed to all model simulators. If a function is
+            provided, it is wrapped in a `LambdaSimulator` with batching enabled.
+        """
         self.simulators = simulators

         if isinstance(shared_simulator, FunctionType):
@@ -51,34 +71,50 @@ def __init__(

     @allow_batch_size
     def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
+        """
+        Sample from the model comparison simulator.
+
+        Parameters
+        ----------
+        batch_shape : Shape
+            The shape of the batch to sample. Typically, a tuple indicating the number of samples,
+            but the user can also supply an int.
+        **kwargs
+            Additional keyword arguments passed to each simulator. These may include outputs from
+            the shared simulator.
+
+        Returns
+        -------
+        data : dict of str to np.ndarray
+            A dictionary containing the sampled outputs. Includes:
+            - outputs from the selected simulator(s)
+            - optionally, outputs from the shared simulator
+            - "model_indices": a one-hot encoded array indicating the model origin of each sample
+        """
         data = {}
         if self.shared_simulator:
             data |= self.shared_simulator.sample(batch_shape, **kwargs)

-        if not self.use_mixed_batches:
-            # draw one model index for the whole batch (faster)
-            model_index = np.random.choice(len(self.simulators), p=npu.softmax(self.logits))
+        softmax_logits = npu.softmax(self.logits)
+        num_models = len(self.simulators)

-            simulator = self.simulators[model_index]
-            data = simulator.sample(batch_shape, **(kwargs | data))
-
-            model_indices = np.full(batch_shape, model_index, dtype="int32")
-            model_indices = npu.one_hot(model_indices, len(self.simulators))
-        else:
-            # generate data randomly from each model (slower)
-            model_counts = np.random.multinomial(n=batch_shape[0], pvals=npu.softmax(self.logits))
-
-            sims = []
-            for n, simulator in zip(model_counts, self.simulators):
-                if n == 0:
-                    continue
-                sim = simulator.sample(n, **(kwargs | data))
-                sims.append(sim)
+        # generate data randomly from each model (slower)
+        if self.use_mixed_batches:
+            model_counts = np.random.multinomial(n=batch_shape[0], pvals=softmax_logits)

+            sims = [
+                simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0
+            ]
             sims = tree_concatenate(sims, numpy=True)
             data |= sims

-            model_indices = np.eye(len(self.simulators), dtype="int32")
-            model_indices = np.repeat(model_indices, model_counts, axis=0)
+            model_indices = np.repeat(np.eye(num_models, dtype="int32"), model_counts, axis=0)
+
+        # draw one model index for the whole batch (faster)
+        else:
+            model_index = np.random.choice(num_models, p=softmax_logits)
+
+            data = self.simulators[model_index].sample(batch_shape, **(kwargs | data))
+            model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models)

         return data | {"model_indices": model_indices}
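
The refactor keeps the two sampling modes but puts the mixed-batch path first and replaces the explicit accumulation loop with a list comprehension. Below is a minimal sketch of the documented constructor and sampling behavior; the class name `ModelComparisonSimulator` and the import path are inferred from the file name (the class statement is not shown in this hunk), and the two toy models are illustrative.

import numpy as np
from bayesflow.simulators import ModelComparisonSimulator, LambdaSimulator  # imports assumed

# Two competing observation models producing outputs with matching keys and shapes.
def model_0():
    return {"y": np.random.normal(loc=0.0, size=5)}

def model_1():
    return {"y": np.random.standard_t(df=3, size=5)}

simulator = ModelComparisonSimulator(
    simulators=[LambdaSimulator(model_0), LambdaSimulator(model_1)],
    p=[0.5, 0.5],              # mutually exclusive with `logits`, per the docstring
    use_mixed_batches=True,    # mix models within a single batch (first branch above)
)

sample = simulator.sample(8)
print(sample["model_indices"])  # one-hot rows indicating which model produced each draw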

bayesflow/simulators/sequential_simulator.py

Lines changed: 33 additions & 0 deletions
@@ -11,11 +11,44 @@ class SequentialSimulator(Simulator):
     """Combines multiple simulators into one, sequentially."""

     def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True):
+        """
+        Initialize a SequentialSimulator.
+
+        Parameters
+        ----------
+        simulators : Sequence[Simulator]
+            A sequence of simulator instances to be executed sequentially. Each simulator should
+            return dictionary outputs and may depend on outputs from previous simulators.
+        expand_outputs : bool, optional
+            If True, 1D output arrays are expanded with an additional dimension at the end.
+            Default is True.
+        """
+
         self.simulators = simulators
         self.expand_outputs = expand_outputs

     @allow_batch_size
     def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
+        """
+        Sample sequentially from the internal simulators.
+
+        Parameters
+        ----------
+        batch_shape : Shape
+            The shape of the batch to sample. Typically, a tuple indicating the number of samples,
+            but it also accepts an int.
+        **kwargs
+            Additional keyword arguments passed to each simulator. These may include previously
+            sampled outputs used as inputs for subsequent simulators.
+
+        Returns
+        -------
+        data : dict of str to np.ndarray
+            A dictionary containing the combined outputs from all simulators. Keys are output names
+            and values are sampled arrays. If `expand_outputs` is True, 1D arrays are expanded to
+            have shape (..., 1).
+        """
+
         data = {}
         for simulator in self.simulators:
             data |= simulator.sample(batch_shape, **(kwargs | data))
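
A short sketch of the sequential composition and the `expand_outputs` behavior described in the new docstrings; the import path and the toy `prior`/`likelihood` functions are assumptions for illustration.

import numpy as np
from bayesflow.simulators import SequentialSimulator, LambdaSimulator  # imports assumed

def prior():
    return {"theta": np.random.normal()}

def likelihood(theta):
    # depends on the previous simulator's output via the matching keyword
    return {"y": np.random.normal(loc=theta, size=10)}

simulator = SequentialSimulator([LambdaSimulator(prior), LambdaSimulator(likelihood)])
sample = simulator.sample(16)

# With expand_outputs=True (the default), 1D outputs such as "theta" gain a trailing axis.
print(sample["theta"].shape, sample["y"].shape)  # expected roughly (16, 1) and (16, 10)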

bayesflow/wrappers/mamba/mamba_block.py

Lines changed: 2 additions & 2 deletions
@@ -55,8 +55,8 @@ def __init__(

         super().__init__(**keras_kwargs(kwargs))

-        # if keras.backend.backend() != "torch":
-        #     raise RuntimeError("Mamba is only available using torch backend.")
+        if keras.backend.backend() != "torch":
+            raise RuntimeError("Mamba is only available using torch backend.")

         try:
             from mamba_ssm import Mamba
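
With the previously commented-out check re-enabled, construction now fails fast unless Keras runs on the torch backend. A hedged sketch of how a user script might select the backend before importing anything Keras-based; the exact `MambaBlock` import path is not shown in this diff and is left commented out as an assumption.

import os

# Select the backend before Keras (and hence bayesflow) is imported,
# otherwise the wrapper raises the RuntimeError added above.
os.environ["KERAS_BACKEND"] = "torch"

import keras  # noqa: E402

assert keras.backend.backend() == "torch"
# from bayesflow.wrappers import MambaBlock  # import path assumed, not shown in this diff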
