Skip to content

Commit 09df093

Browse files
committed
Merge branch 'dev' into fix_sampling_method_kwargs
2 parents ca7f3bd + 08ed995 commit 09df093

File tree

5 files changed

+100
-26
lines changed

5 files changed

+100
-26
lines changed

bayesflow/diagnostics/plots/recovery.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1-
from collections.abc import Sequence, Mapping
1+
from collections.abc import Sequence, Mapping, Callable
22

33
import matplotlib.pyplot as plt
44
import numpy as np
55

6-
from scipy.stats import median_abs_deviation
7-
86
from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric
7+
from bayesflow.utils.numpy_utils import credible_interval
98

109

1110
def recovery(
1211
estimates: Mapping[str, np.ndarray] | np.ndarray,
1312
targets: Mapping[str, np.ndarray] | np.ndarray,
1413
variable_keys: Sequence[str] = None,
1514
variable_names: Sequence[str] = None,
16-
point_agg=np.median,
17-
uncertainty_agg=median_abs_deviation,
15+
point_agg: Callable = np.median,
16+
uncertainty_agg: Callable = credible_interval,
17+
point_agg_kwargs: dict = None,
18+
uncertainty_agg_kwargs: dict = None,
1819
add_corr: bool = True,
1920
figsize: Sequence[int] = None,
2021
label_fontsize: int = 16,
@@ -57,8 +58,17 @@ def recovery(
5758
By default, select all keys.
5859
variable_names : list or None, optional, default: None
5960
The individual parameter names for nice plot titles. Inferred if None
60-
point_agg : function to compute point estimates. Default: median
61-
uncertainty_agg : function to compute uncertainty estimates. Default: MAD
61+
point_agg : callable, optional, default: median
62+
Function to compute point estimates.
63+
uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95%
64+
Function to compute a measure of uncertainty. Can either be the lower and upper
65+
uncertainty bounds provided with the shape (2, num_datasets, num_params) or a
66+
scalar measure of uncertainty (e.g., the median absolute deviation) with shape
67+
(num_datasets, num_params).
68+
point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
69+
uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
70+
For example, to change the coverage probability of credible_interval to 50%,
71+
use uncertainty_agg_kwargs = dict(prob=0.5)
6272
add_corr : boolean, default: True
6373
Should correlations between estimates and ground truth values be shown?
6474
figsize : tuple or None, optional, default : None
@@ -106,11 +116,18 @@ def recovery(
106116
estimates = plot_data.pop("estimates")
107117
targets = plot_data.pop("targets")
108118

119+
point_agg_kwargs = point_agg_kwargs or {}
120+
uncertainty_agg_kwargs = uncertainty_agg_kwargs or {}
121+
109122
# Compute point estimates and uncertainties
110-
point_estimate = point_agg(estimates, axis=1)
123+
point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs)
111124

112125
if uncertainty_agg is not None:
113-
u = uncertainty_agg(estimates, axis=1)
126+
u = uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs)
127+
if u.ndim == 3:
128+
# compute lower and upper error
129+
u[0, :, :] = point_estimate - u[0, :, :]
130+
u[1, :, :] = u[1, :, :] - point_estimate
114131

115132
for i, ax in enumerate(plot_data["axes"].flat):
116133
if i >= plot_data["num_variables"]:
@@ -121,7 +138,7 @@ def recovery(
121138
_ = ax.errorbar(
122139
targets[:, i],
123140
point_estimate[:, i],
124-
yerr=u[:, i],
141+
yerr=u[..., i],
125142
fmt="o",
126143
alpha=0.5,
127144
color=color,

bayesflow/distributions/diagonal_student_t.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,19 @@ def __init__(
6363

6464
self.seed_generator = seed_generator or keras.random.SeedGenerator()
6565

66-
self.dim = None
66+
self.dims = None
6767
self._loc = None
6868
self._scale = None
6969

7070
def build(self, input_shape: Shape) -> None:
7171
if self.built:
7272
return
7373

74-
self.dim = int(input_shape[-1])
74+
self.dims = tuple(input_shape[1:])
7575

7676
# convert to tensor and broadcast if necessary
77-
self.loc = ops.cast(ops.broadcast_to(self.loc, (self.dim,)), "float32")
78-
self.scale = ops.cast(ops.broadcast_to(self.scale, (self.dim,)), "float32")
77+
self.loc = ops.cast(ops.broadcast_to(self.loc, self.dims), "float32")
78+
self.scale = ops.cast(ops.broadcast_to(self.scale, self.dims), "float32")
7979

8080
if self.trainable_parameters:
8181
self._loc = self.add_weight(
@@ -96,14 +96,14 @@ def build(self, input_shape: Shape) -> None:
9696

9797
def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
9898
mahalanobis_term = ops.sum((samples - self._loc) ** 2 / self._scale**2, axis=-1)
99-
result = -0.5 * (self.df + self.dim) * ops.log1p(mahalanobis_term / self.df)
99+
result = -0.5 * (self.df + sum(self.dims)) * ops.log1p(mahalanobis_term / self.df)
100100

101101
if normalize:
102102
log_normalization_constant = (
103-
-0.5 * self.dim * math.log(self.df)
104-
- 0.5 * self.dim * math.log(math.pi)
103+
-0.5 * sum(self.dims) * math.log(self.df)
104+
- 0.5 * sum(self.dims) * math.log(math.pi)
105105
- math.lgamma(0.5 * self.df)
106-
+ math.lgamma(0.5 * (self.df + self.dim))
106+
+ math.lgamma(0.5 * (self.df + sum(self.dims)))
107107
- ops.sum(keras.ops.log(self._scale))
108108
)
109109
result += log_normalization_constant
@@ -119,9 +119,10 @@ def sample(self, batch_shape: Shape) -> Tensor:
119119

120120
# The chi-quare samples need to be repeated across self.dim
121121
# since for each element of batch_shape only one sample is created.
122-
chi2_samples = expand_tile(chi2_samples, n=self.dim, axis=-1)
122+
chi2_samples = expand_tile(chi2_samples, n=sum(self.dims), axis=-1)
123+
chi2_samples = keras.ops.reshape(chi2_samples, batch_shape + self.dims)
123124

124-
normal_samples = keras.random.normal(batch_shape + (self.dim,), seed=self.seed_generator)
125+
normal_samples = keras.random.normal(batch_shape + self.dims, seed=self.seed_generator)
125126

126127
return self._loc + self._scale * normal_samples * ops.sqrt(self.df / chi2_samples)
127128

bayesflow/distributions/mixture.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(
5959

6060
self.trainable_mixture = trainable_mixture
6161

62-
self.dim = None
62+
self.dims = None
6363
self._mixture_logits = None
6464

6565
@allow_batch_size
@@ -78,7 +78,7 @@ def sample(self, batch_shape: Shape) -> Tensor:
7878
Returns
7979
-------
8080
samples: Tensor
81-
A tensor of shape `batch_shape + (dim,)` containing samples drawn
81+
A tensor of shape `batch_shape + dims` containing samples drawn
8282
from the mixture.
8383
"""
8484
# Will use numpy until keras adds support for N-D categorical sampling
@@ -87,7 +87,7 @@ def sample(self, batch_shape: Shape) -> Tensor:
8787
cat_samples = cat_samples.argmax(axis=-1)
8888

8989
# Prepare array to fill and dtype to infer
90-
samples = np.zeros(batch_shape + (self.dim,))
90+
samples = np.zeros(batch_shape + self.dims)
9191
dtype = None
9292

9393
# Fill in array with vectorized sampling per component
@@ -137,7 +137,7 @@ def build(self, input_shape: Shape) -> None:
137137
if self.built:
138138
return
139139

140-
self.dim = input_shape[-1]
140+
self.dims = tuple(input_shape[1:])
141141

142142
for distribution in self.distributions:
143143
distribution.build(input_shape)

bayesflow/utils/numpy_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from scipy import special
3+
from collections.abc import Sequence
34

45

56
def inverse_sigmoid(x: np.ndarray) -> np.ndarray:
@@ -42,3 +43,47 @@ def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.nd
4243
with np.errstate(over="ignore"):
4344
exp_beta_x = np.exp(beta * x)
4445
return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta)
46+
47+
48+
def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] | int = None, **kwargs) -> np.ndarray:
49+
"""
50+
Compute credible interval from samples using quantiles.
51+
52+
Parameters
53+
----------
54+
x : array_like
55+
Input array of samples from a posterior distribution or bootstrap samples.
56+
prob : float, default 0.95
57+
Coverage probability of the credible interval (between 0 and 1).
58+
For example, 0.95 gives a 95% credible interval.
59+
axis : Sequence[int]
60+
Axis or axes along which the credible interval is computed.
61+
Default is None (flatten array).
62+
63+
Returns
64+
-------
65+
a numpy array of shape (2, ...) with the first dimension indicating the
66+
lower and upper bounds of the credible interval.
67+
68+
Examples
69+
--------
70+
>>> import numpy as np
71+
>>> # Simulate posterior samples
72+
>>> samples = np.random.normal(size=(10, 1000, 3))
73+
74+
>>> # Different coverage probabilities
75+
>>> credible_interval(samples, prob=0.5, axis=1) # 50% CI
76+
>>> credible_interval(samples, prob=0.99, axis=1) # 99% CI
77+
"""
78+
79+
# Input validation
80+
if not 0 <= prob <= 1:
81+
raise ValueError(f"prob must be between 0 and 1, got {prob}")
82+
83+
# Calculate tail probabilities
84+
alpha = 1 - prob
85+
lower_q = alpha / 2
86+
upper_q = 1 - alpha / 2
87+
88+
# Compute quantiles
89+
return np.quantile(x, q=(lower_q, upper_q), axis=axis, **kwargs)

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,20 @@ def test_loss(history):
9292
assert out.axes[0].title._text == "Loss Trajectory"
9393

9494

95-
def test_recovery(random_estimates, random_targets):
95+
def test_recovery_bounds(random_estimates, random_targets):
9696
# basic functionality: automatic variable names
97-
out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4)
97+
from bayesflow.utils.numpy_utils import credible_interval
98+
99+
out = bf.diagnostics.plots.recovery(
100+
random_estimates, random_targets, markersize=4, uncertainty_agg=credible_interval
101+
)
102+
assert len(out.axes) == num_variables(random_estimates)
103+
assert out.axes[2].title._text == "sigma"
104+
105+
106+
def test_recovery_symmetric(random_estimates, random_targets):
107+
# basic functionality: automatic variable names
108+
out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4, uncertainty_agg=np.std)
98109
assert len(out.axes) == num_variables(random_estimates)
99110
assert out.axes[2].title._text == "sigma"
100111

0 commit comments

Comments
 (0)