Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions bayesflow/diagnostics/plots/recovery.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from collections.abc import Sequence, Mapping
from collections.abc import Sequence, Mapping, Callable

import matplotlib.pyplot as plt
import numpy as np

from scipy.stats import median_abs_deviation

from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric
from bayesflow.utils.numpy_utils import credible_interval


def recovery(
estimates: Mapping[str, np.ndarray] | np.ndarray,
targets: Mapping[str, np.ndarray] | np.ndarray,
variable_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
point_agg=np.median,
uncertainty_agg=median_abs_deviation,
point_agg: Callable = np.median,
uncertainty_agg: Callable = credible_interval,
point_agg_kwargs: dict = None,
uncertainty_agg_kwargs: dict = None,
add_corr: bool = True,
figsize: Sequence[int] = None,
label_fontsize: int = 16,
Expand Down Expand Up @@ -57,8 +58,17 @@ def recovery(
By default, select all keys.
variable_names : list or None, optional, default: None
The individual parameter names for nice plot titles. Inferred if None
point_agg : function to compute point estimates. Default: median
uncertainty_agg : function to compute uncertainty estimates. Default: MAD
point_agg : callable, optional, default: median
Function to compute point estimates.
uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95%
Function to compute a measure of uncertainty. Can either be the lower and upper
uncertainty bounds provided with the shape (2, num_datasets, num_params) or a
scalar measure of uncertainty (e.g., the median absolute deviation) with shape
(num_datasets, num_params).
point_agg_kwargs : Optional dictionary of further arguments passed to point_agg.
uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg.
For example, to change the coverage probability of credible_interval to 50%,
use uncertainty_agg_kwargs = dict(prob=0.5)
add_corr : boolean, default: True
Should correlations between estimates and ground truth values be shown?
figsize : tuple or None, optional, default : None
Expand Down Expand Up @@ -106,11 +116,18 @@ def recovery(
estimates = plot_data.pop("estimates")
targets = plot_data.pop("targets")

point_agg_kwargs = point_agg_kwargs or {}
uncertainty_agg_kwargs = uncertainty_agg_kwargs or {}

# Compute point estimates and uncertainties
point_estimate = point_agg(estimates, axis=1)
point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs)

if uncertainty_agg is not None:
u = uncertainty_agg(estimates, axis=1)
u = uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs)
if u.ndim == 3:
# compute lower and upper error
u[0, :, :] = point_estimate - u[0, :, :]
u[1, :, :] = u[1, :, :] - point_estimate

for i, ax in enumerate(plot_data["axes"].flat):
if i >= plot_data["num_variables"]:
Expand All @@ -121,7 +138,7 @@ def recovery(
_ = ax.errorbar(
targets[:, i],
point_estimate[:, i],
yerr=u[:, i],
yerr=u[..., i],
fmt="o",
alpha=0.5,
color=color,
Expand Down
45 changes: 45 additions & 0 deletions bayesflow/utils/numpy_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from scipy import special
from collections.abc import Sequence


def inverse_sigmoid(x: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -42,3 +43,47 @@ def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.nd
with np.errstate(over="ignore"):
exp_beta_x = np.exp(beta * x)
return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta)


def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] | int = None, **kwargs) -> np.ndarray:
"""
Compute credible interval from samples using quantiles.

Parameters
----------
x : array_like
Input array of samples from a posterior distribution or bootstrap samples.
prob : float, default 0.95
Coverage probability of the credible interval (between 0 and 1).
For example, 0.95 gives a 95% credible interval.
axis : Sequence[int]
Axis or axes along which the credible interval is computed.
Default is None (flatten array).

Returns
-------
a numpy array of shape (2, ...) with the first dimension indicating the
lower and upper bounds of the credible interval.

Examples
--------
>>> import numpy as np
>>> # Simulate posterior samples
>>> samples = np.random.normal(size=(10, 1000, 3))

>>> # Different coverage probabilities
>>> credible_interval(samples, prob=0.5, axis=1) # 50% CI
>>> credible_interval(samples, prob=0.99, axis=1) # 99% CI
"""

# Input validation
if not 0 <= prob <= 1:
raise ValueError(f"prob must be between 0 and 1, got {prob}")

# Calculate tail probabilities
alpha = 1 - prob
lower_q = alpha / 2
upper_q = 1 - alpha / 2

# Compute quantiles
return np.quantile(x, q=(lower_q, upper_q), axis=axis, **kwargs)
15 changes: 13 additions & 2 deletions tests/test_diagnostics/test_diagnostics_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,20 @@ def test_loss(history):
assert out.axes[0].title._text == "Loss Trajectory"


def test_recovery(random_estimates, random_targets):
def test_recovery_bounds(random_estimates, random_targets):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4)
from bayesflow.utils.numpy_utils import credible_interval

out = bf.diagnostics.plots.recovery(
random_estimates, random_targets, markersize=4, uncertainty_agg=credible_interval
)
assert len(out.axes) == num_variables(random_estimates)
assert out.axes[2].title._text == "sigma"


def test_recovery_symmetric(random_estimates, random_targets):
    """Smoke-test the recovery plot when ``uncertainty_agg`` is ``np.std``."""
    figure = bf.diagnostics.plots.recovery(
        random_estimates,
        random_targets,
        markersize=4,
        uncertainty_agg=np.std,
    )
    axes = figure.axes

    # The number of axes must match the number of variables in the estimates.
    assert len(axes) == num_variables(random_estimates)
    # Variable names are not passed explicitly; the third panel title is checked.
    assert axes[2].title._text == "sigma"

Expand Down
Loading