Skip to content

Commit 08ed995

Browse files
paul-buerknerstefanradev93vpratz
authored
Use credible intervals for uncertainties in recovery plots (#573)
* use credible intervals for uncertainties in recovery plots * Add kwargs for agg functions * make point_arg_kwargs and uncertainty_agg_kwargs explicit arguments * adapt docs, minor stylistic changes * add support for symmetric uncertainty measures * add test for symmetric uncertainty measure in recovery --------- Co-authored-by: stefanradev93 <[email protected]> Co-authored-by: Valentin Pratz <[email protected]>
1 parent 4d8596b commit 08ed995

File tree

3 files changed

+85
-12
lines changed

3 files changed

+85
-12
lines changed

bayesflow/diagnostics/plots/recovery.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1-
from collections.abc import Sequence, Mapping
1+
from collections.abc import Sequence, Mapping, Callable
22

33
import matplotlib.pyplot as plt
44
import numpy as np
55

6-
from scipy.stats import median_abs_deviation
7-
86
from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric
7+
from bayesflow.utils.numpy_utils import credible_interval
98

109

1110
def recovery(
1211
estimates: Mapping[str, np.ndarray] | np.ndarray,
1312
targets: Mapping[str, np.ndarray] | np.ndarray,
1413
variable_keys: Sequence[str] = None,
1514
variable_names: Sequence[str] = None,
16-
point_agg=np.median,
17-
uncertainty_agg=median_abs_deviation,
15+
point_agg: Callable = np.median,
16+
uncertainty_agg: Callable = credible_interval,
17+
point_agg_kwargs: dict = None,
18+
uncertainty_agg_kwargs: dict = None,
1819
add_corr: bool = True,
1920
figsize: Sequence[int] = None,
2021
label_fontsize: int = 16,
@@ -57,8 +58,17 @@ def recovery(
5758
By default, select all keys.
5859
variable_names : list or None, optional, default: None
5960
The individual parameter names for nice plot titles. Inferred if None
60-
point_agg : function to compute point estimates. Default: median
61-
uncertainty_agg : function to compute uncertainty estimates. Default: MAD
61+
point_agg : callable, optional, default: median
62+
Function to compute point estimates.
63+
uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95%
64+
Function to compute a measure of uncertainty. Can either be the lower and upper
65+
uncertainty bounds provided with the shape (2, num_datasets, num_params) or a
66+
scalar measure of uncertainty (e.g., the median absolute deviation) with shape
67+
(num_datasets, num_params).
68+
point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
69+
uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
70+
For example, to change the coverage probability of credible_interval to 50%,
71+
use uncertainty_agg_kwargs = dict(prob=0.5)
6272
add_corr : boolean, default: True
6373
Should correlations between estimates and ground truth values be shown?
6474
figsize : tuple or None, optional, default : None
@@ -106,11 +116,18 @@ def recovery(
106116
estimates = plot_data.pop("estimates")
107117
targets = plot_data.pop("targets")
108118

119+
point_agg_kwargs = point_agg_kwargs or {}
120+
uncertainty_agg_kwargs = uncertainty_agg_kwargs or {}
121+
109122
# Compute point estimates and uncertainties
110-
point_estimate = point_agg(estimates, axis=1)
123+
point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs)
111124

112125
if uncertainty_agg is not None:
113-
u = uncertainty_agg(estimates, axis=1)
126+
u = uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs)
127+
if u.ndim == 3:
128+
# compute lower and upper error
129+
u[0, :, :] = point_estimate - u[0, :, :]
130+
u[1, :, :] = u[1, :, :] - point_estimate
114131

115132
for i, ax in enumerate(plot_data["axes"].flat):
116133
if i >= plot_data["num_variables"]:
@@ -121,7 +138,7 @@ def recovery(
121138
_ = ax.errorbar(
122139
targets[:, i],
123140
point_estimate[:, i],
124-
yerr=u[:, i],
141+
yerr=u[..., i],
125142
fmt="o",
126143
alpha=0.5,
127144
color=color,

bayesflow/utils/numpy_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from scipy import special
3+
from collections.abc import Sequence
34

45

56
def inverse_sigmoid(x: np.ndarray) -> np.ndarray:
@@ -42,3 +43,47 @@ def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.nd
4243
with np.errstate(over="ignore"):
4344
exp_beta_x = np.exp(beta * x)
4445
return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta)
46+
47+
48+
def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] | int = None, **kwargs) -> np.ndarray:
49+
"""
50+
Compute credible interval from samples using quantiles.
51+
52+
Parameters
53+
----------
54+
x : array_like
55+
Input array of samples from a posterior distribution or bootstrap samples.
56+
prob : float, default 0.95
57+
Coverage probability of the credible interval (between 0 and 1).
58+
For example, 0.95 gives a 95% credible interval.
59+
axis : Sequence[int]
60+
Axis or axes along which the credible interval is computed.
61+
Default is None (flatten array).
62+
63+
Returns
64+
-------
65+
a numpy array of shape (2, ...) with the first dimension indicating the
66+
lower and upper bounds of the credible interval.
67+
68+
Examples
69+
--------
70+
>>> import numpy as np
71+
>>> # Simulate posterior samples
72+
>>> samples = np.random.normal(size=(10, 1000, 3))
73+
74+
>>> # Different coverage probabilities
75+
>>> credible_interval(samples, prob=0.5, axis=1) # 50% CI
76+
>>> credible_interval(samples, prob=0.99, axis=1) # 99% CI
77+
"""
78+
79+
# Input validation
80+
if not 0 <= prob <= 1:
81+
raise ValueError(f"prob must be between 0 and 1, got {prob}")
82+
83+
# Calculate tail probabilities
84+
alpha = 1 - prob
85+
lower_q = alpha / 2
86+
upper_q = 1 - alpha / 2
87+
88+
# Compute quantiles
89+
return np.quantile(x, q=(lower_q, upper_q), axis=axis, **kwargs)

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,20 @@ def test_loss(history):
9292
assert out.axes[0].title._text == "Loss Trajectory"
9393

9494

95-
def test_recovery(random_estimates, random_targets):
95+
def test_recovery_bounds(random_estimates, random_targets):
9696
# basic functionality: automatic variable names
97-
out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4)
97+
from bayesflow.utils.numpy_utils import credible_interval
98+
99+
out = bf.diagnostics.plots.recovery(
100+
random_estimates, random_targets, markersize=4, uncertainty_agg=credible_interval
101+
)
102+
assert len(out.axes) == num_variables(random_estimates)
103+
assert out.axes[2].title._text == "sigma"
104+
105+
106+
def test_recovery_symmetric(random_estimates, random_targets):
107+
# basic functionality: automatic variable names
108+
out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4, uncertainty_agg=np.std)
98109
assert len(out.axes) == num_variables(random_estimates)
99110
assert out.axes[2].title._text == "sigma"
100111

0 commit comments

Comments
 (0)