Skip to content

Commit 922412f

Browse files
committed
Merge branch 'fix_sampling_method_kwargs' into compositional_sampling_diffusion
2 parents b2ef755 + ea0659d commit 922412f

File tree

8 files changed

+193
-29
lines changed

8 files changed

+193
-29
lines changed

bayesflow/diagnostics/plots/recovery.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1-
from collections.abc import Sequence, Mapping
1+
from collections.abc import Sequence, Mapping, Callable
22

33
import matplotlib.pyplot as plt
44
import numpy as np
55

6-
from scipy.stats import median_abs_deviation
7-
86
from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric
7+
from bayesflow.utils.numpy_utils import credible_interval
98

109

1110
def recovery(
1211
estimates: Mapping[str, np.ndarray] | np.ndarray,
1312
targets: Mapping[str, np.ndarray] | np.ndarray,
1413
variable_keys: Sequence[str] = None,
1514
variable_names: Sequence[str] = None,
16-
point_agg=np.median,
17-
uncertainty_agg=median_abs_deviation,
15+
point_agg: Callable = np.median,
16+
uncertainty_agg: Callable = credible_interval,
17+
point_agg_kwargs: dict = None,
18+
uncertainty_agg_kwargs: dict = None,
1819
add_corr: bool = True,
1920
figsize: Sequence[int] = None,
2021
label_fontsize: int = 16,
@@ -57,8 +58,17 @@ def recovery(
5758
By default, select all keys.
5859
variable_names : list or None, optional, default: None
5960
The individual parameter names for nice plot titles. Inferred if None
60-
point_agg : function to compute point estimates. Default: median
61-
uncertainty_agg : function to compute uncertainty estimates. Default: MAD
61+
point_agg : callable, optional, default: median
62+
Function to compute point estimates.
63+
uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95%
64+
Function to compute a measure of uncertainty. It must return either the lower and upper
65+
uncertainty bounds provided with the shape (2, num_datasets, num_params) or a
66+
scalar measure of uncertainty (e.g., the median absolute deviation) with shape
67+
(num_datasets, num_params).
68+
point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
69+
uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
70+
For example, to change the coverage probability of credible_interval to 50%,
71+
use uncertainty_agg_kwargs = dict(prob=0.5)
6272
add_corr : boolean, default: True
6373
Should correlations between estimates and ground truth values be shown?
6474
figsize : tuple or None, optional, default : None
@@ -106,11 +116,18 @@ def recovery(
106116
estimates = plot_data.pop("estimates")
107117
targets = plot_data.pop("targets")
108118

119+
point_agg_kwargs = point_agg_kwargs or {}
120+
uncertainty_agg_kwargs = uncertainty_agg_kwargs or {}
121+
109122
# Compute point estimates and uncertainties
110-
point_estimate = point_agg(estimates, axis=1)
123+
point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs)
111124

112125
if uncertainty_agg is not None:
113-
u = uncertainty_agg(estimates, axis=1)
126+
u = uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs)
127+
if u.ndim == 3:
128+
# convert the lower/upper interval bounds into error-bar lengths relative to the point estimate
129+
u[0, :, :] = point_estimate - u[0, :, :]
130+
u[1, :, :] = u[1, :, :] - point_estimate
114131

115132
for i, ax in enumerate(plot_data["axes"].flat):
116133
if i >= plot_data["num_variables"]:
@@ -121,7 +138,7 @@ def recovery(
121138
_ = ax.errorbar(
122139
targets[:, i],
123140
point_estimate[:, i],
124-
yerr=u[:, i],
141+
yerr=u[..., i],
125142
fmt="o",
126143
alpha=0.5,
127144
color=color,

bayesflow/distributions/diagonal_student_t.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,19 @@ def __init__(
6363

6464
self.seed_generator = seed_generator or keras.random.SeedGenerator()
6565

66-
self.dim = None
66+
self.dims = None
6767
self._loc = None
6868
self._scale = None
6969

7070
def build(self, input_shape: Shape) -> None:
7171
if self.built:
7272
return
7373

74-
self.dim = int(input_shape[-1])
74+
self.dims = tuple(input_shape[1:])
7575

7676
# convert to tensor and broadcast if necessary
77-
self.loc = ops.cast(ops.broadcast_to(self.loc, (self.dim,)), "float32")
78-
self.scale = ops.cast(ops.broadcast_to(self.scale, (self.dim,)), "float32")
77+
self.loc = ops.cast(ops.broadcast_to(self.loc, self.dims), "float32")
78+
self.scale = ops.cast(ops.broadcast_to(self.scale, self.dims), "float32")
7979

8080
if self.trainable_parameters:
8181
self._loc = self.add_weight(
@@ -96,14 +96,14 @@ def build(self, input_shape: Shape) -> None:
9696

9797
def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
9898
mahalanobis_term = ops.sum((samples - self._loc) ** 2 / self._scale**2, axis=-1)
99-
result = -0.5 * (self.df + self.dim) * ops.log1p(mahalanobis_term / self.df)
99+
result = -0.5 * (self.df + sum(self.dims)) * ops.log1p(mahalanobis_term / self.df)
100100

101101
if normalize:
102102
log_normalization_constant = (
103-
-0.5 * self.dim * math.log(self.df)
104-
- 0.5 * self.dim * math.log(math.pi)
103+
-0.5 * sum(self.dims) * math.log(self.df)
104+
- 0.5 * sum(self.dims) * math.log(math.pi)
105105
- math.lgamma(0.5 * self.df)
106-
+ math.lgamma(0.5 * (self.df + self.dim))
106+
+ math.lgamma(0.5 * (self.df + sum(self.dims)))
107107
- ops.sum(keras.ops.log(self._scale))
108108
)
109109
result += log_normalization_constant
@@ -119,9 +119,10 @@ def sample(self, batch_shape: Shape) -> Tensor:
119119

120120
# The chi-square samples need to be repeated across self.dim
121121
# since for each element of batch_shape only one sample is created.
122-
chi2_samples = expand_tile(chi2_samples, n=self.dim, axis=-1)
122+
chi2_samples = expand_tile(chi2_samples, n=sum(self.dims), axis=-1)
123+
chi2_samples = keras.ops.reshape(chi2_samples, batch_shape + self.dims)
123124

124-
normal_samples = keras.random.normal(batch_shape + (self.dim,), seed=self.seed_generator)
125+
normal_samples = keras.random.normal(batch_shape + self.dims, seed=self.seed_generator)
125126

126127
return self._loc + self._scale * normal_samples * ops.sqrt(self.df / chi2_samples)
127128

bayesflow/distributions/mixture.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959

6060
self.trainable_mixture = trainable_mixture
6161

62-
self.dim = None
62+
self.dims = None
6363
self._mixture_logits = None
6464

6565
@allow_batch_size
@@ -78,7 +78,7 @@ def sample(self, batch_shape: Shape) -> Tensor:
7878
Returns
7979
-------
8080
samples: Tensor
81-
A tensor of shape `batch_shape + (dim,)` containing samples drawn
81+
A tensor of shape `batch_shape + dims` containing samples drawn
8282
from the mixture.
8383
"""
8484
# Will use numpy until keras adds support for N-D categorical sampling
@@ -87,7 +87,7 @@ def sample(self, batch_shape: Shape) -> Tensor:
8787
cat_samples = cat_samples.argmax(axis=-1)
8888

8989
# Prepare array to fill and dtype to infer
90-
samples = np.zeros(batch_shape + (self.dim,))
90+
samples = np.zeros(batch_shape + self.dims)
9191
dtype = None
9292

9393
# Fill in array with vectorized sampling per component
@@ -137,7 +137,7 @@ def build(self, input_shape: Shape) -> None:
137137
if self.built:
138138
return
139139

140-
self.dim = input_shape[-1]
140+
self.dims = tuple(input_shape[1:])
141141

142142
for distribution in self.distributions:
143143
distribution.build(input_shape)

bayesflow/networks/transformers/mab.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from bayesflow.networks import MLP
55
from bayesflow.types import Tensor
6-
from bayesflow.utils import layer_kwargs
6+
from bayesflow.utils import layer_kwargs, filter_kwargs
77
from bayesflow.utils.decorators import sanitize_input_shape
88
from bayesflow.utils.serialization import serializable
99

@@ -111,7 +111,7 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -
111111
"""
112112

113113
h = self.input_projector(seq_x) + self.attention(
114-
query=seq_x, key=seq_y, value=seq_y, training=training, **kwargs
114+
query=seq_x, key=seq_y, value=seq_y, training=training, **filter_kwargs(kwargs, self.attention.call)
115115
)
116116
if self.ln_pre is not None:
117117
h = self.ln_pre(h, training=training)

bayesflow/networks/transformers/set_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
147147
out : Tensor
148148
Output of shape (batch_size, set_size, output_dim)
149149
"""
150-
summary = self.attention_blocks(input_set, training=training, **kwargs)
150+
summary = self.attention_blocks(input_set, training=training)
151151
summary = self.pooling_by_attention(summary, training=training, **kwargs)
152152
summary = self.output_projector(summary)
153153
return summary

bayesflow/utils/numpy_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from scipy import special
3+
from collections.abc import Sequence
34

45

56
def inverse_sigmoid(x: np.ndarray) -> np.ndarray:
@@ -42,3 +43,47 @@ def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.nd
4243
with np.errstate(over="ignore"):
4344
exp_beta_x = np.exp(beta * x)
4445
return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta)
46+
47+
48+
def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] | int = None, **kwargs) -> np.ndarray:
49+
"""
50+
Compute an equal-tailed credible interval from samples using quantiles.
51+
52+
Parameters
53+
----------
54+
x : array_like
55+
Input array of samples from a posterior distribution or bootstrap samples.
56+
prob : float, default 0.95
57+
Coverage probability of the credible interval (between 0 and 1).
58+
For example, 0.95 gives a 95% credible interval.
59+
axis : int or Sequence[int], optional
60+
Axis or axes along which the credible interval is computed.
61+
Default is None (flatten array).
62+
63+
Returns
64+
-------
65+
a numpy array of shape (2, ...) with the first dimension indicating the
66+
lower and upper bounds of the credible interval.
67+
68+
Examples
69+
--------
70+
>>> import numpy as np
71+
>>> # Simulate posterior samples
72+
>>> samples = np.random.normal(size=(10, 1000, 3))
73+
74+
>>> # Different coverage probabilities
75+
>>> credible_interval(samples, prob=0.5, axis=1) # 50% CI
76+
>>> credible_interval(samples, prob=0.99, axis=1) # 99% CI
77+
"""
78+
79+
# Input validation
80+
if not 0 <= prob <= 1:
81+
raise ValueError(f"prob must be between 0 and 1, got {prob}")
82+
83+
# Calculate tail probabilities
84+
alpha = 1 - prob
85+
lower_q = alpha / 2
86+
upper_q = 1 - alpha / 2
87+
88+
# Compute quantiles
89+
return np.quantile(x, q=(lower_q, upper_q), axis=axis, **kwargs)

tests/test_approximators/test_sample.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import keras
23
from tests.utils import check_combination_simulator_adapter
34

@@ -16,3 +17,92 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
1617
samples = approximator.sample(num_samples=2, conditions=data)
1718

1819
assert isinstance(samples, dict)
20+
21+
22+
@pytest.mark.parametrize("inference_network_type", ["flow_matching", "diffusion_model"])
23+
@pytest.mark.parametrize("summary_network_type", ["none", "deep_set", "set_transformer", "time_series"])
24+
@pytest.mark.parametrize("method", ["euler", "rk45", "euler_maruyama"])
25+
def test_approximator_sample_with_integration_methods(
26+
inference_network_type, summary_network_type, method, simulator, adapter
27+
):
28+
"""Test approximator sampling with different integration methods and summary networks.
29+
30+
Tests flow matching and diffusion models with different ODE/SDE solvers:
31+
- euler, rk45: Available for both flow matching and diffusion models
32+
- euler_maruyama: Only for diffusion models (stochastic)
33+
34+
Also tests with different summary network types.
35+
"""
36+
batch_size = 8 # Use smaller batch size for faster tests
37+
check_combination_simulator_adapter(simulator, adapter)
38+
39+
# Skip euler_maruyama for flow matching (deterministic model)
40+
if inference_network_type == "flow_matching" and method == "euler_maruyama":
41+
pytest.skip("euler_maruyama is only available for diffusion models")
42+
43+
# Create inference network based on type
44+
if inference_network_type == "flow_matching":
45+
from bayesflow.networks import FlowMatching, MLP
46+
47+
inference_network = FlowMatching(
48+
subnet=MLP(widths=[32, 32]),
49+
integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests
50+
)
51+
elif inference_network_type == "diffusion_model":
52+
from bayesflow.networks import DiffusionModel, MLP
53+
54+
inference_network = DiffusionModel(
55+
subnet=MLP(widths=[32, 32]),
56+
integrate_kwargs={"steps": 10}, # Use fewer steps for faster tests
57+
)
58+
else:
59+
pytest.skip(f"Unsupported inference network type: {inference_network_type}")
60+
61+
# Create summary network based on type
62+
summary_network = None
63+
if summary_network_type != "none":
64+
if summary_network_type == "deep_set":
65+
from bayesflow.networks import DeepSet, MLP
66+
67+
summary_network = DeepSet(subnet=MLP(widths=[16, 16]))
68+
elif summary_network_type == "set_transformer":
69+
from bayesflow.networks import SetTransformer
70+
71+
summary_network = SetTransformer(embed_dims=[16, 16], mlp_widths=[16, 16])
72+
elif summary_network_type == "time_series":
73+
from bayesflow.networks import TimeSeriesNetwork
74+
75+
summary_network = TimeSeriesNetwork(subnet_kwargs={"widths": [16, 16]}, cell_type="lstm")
76+
else:
77+
pytest.skip(f"Unsupported summary network type: {summary_network_type}")
78+
79+
# Update adapter to include summary variables if summary network is present
80+
from bayesflow import ContinuousApproximator
81+
82+
adapter = ContinuousApproximator.build_adapter(
83+
inference_variables=["mean", "std"],
84+
summary_variables=["x"], # Use x as summary variable for testing
85+
)
86+
87+
# Create approximator
88+
from bayesflow import ContinuousApproximator
89+
90+
approximator = ContinuousApproximator(
91+
adapter=adapter, inference_network=inference_network, summary_network=summary_network
92+
)
93+
94+
# Generate test data
95+
num_batches = 2 # Use fewer batches for faster tests
96+
data = simulator.sample((num_batches * batch_size,))
97+
98+
# Build approximator
99+
batch = adapter(data)
100+
batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
101+
batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
102+
approximator.build(batch_shapes)
103+
104+
# Test sampling with the specified method
105+
samples = approximator.sample(num_samples=2, conditions=data, method=method)
106+
107+
# Verify results
108+
assert isinstance(samples, dict)

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,20 @@ def test_loss(history):
9292
assert out.axes[0].title._text == "Loss Trajectory"
9393

9494

95-
def test_recovery(random_estimates, random_targets):
95+
def test_recovery_bounds(random_estimates, random_targets):
9696
# basic functionality: automatic variable names
97-
out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4)
97+
from bayesflow.utils.numpy_utils import credible_interval
98+
99+
out = bf.diagnostics.plots.recovery(
100+
random_estimates, random_targets, markersize=4, uncertainty_agg=credible_interval
101+
)
102+
assert len(out.axes) == num_variables(random_estimates)
103+
assert out.axes[2].title._text == "sigma"
104+
105+
106+
def test_recovery_symmetric(random_estimates, random_targets):
107+
# basic functionality: automatic variable names
108+
out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4, uncertainty_agg=np.std)
98109
assert len(out.axes) == num_variables(random_estimates)
99110
assert out.axes[2].title._text == "sigma"
100111

0 commit comments

Comments
 (0)