diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index 462f06546..6f30e1bdd 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -1,11 +1,10 @@ -from collections.abc import Sequence, Mapping +from collections.abc import Sequence, Mapping, Callable import matplotlib.pyplot as plt import numpy as np -from scipy.stats import median_abs_deviation - from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric +from bayesflow.utils.numpy_utils import credible_interval def recovery( @@ -13,8 +12,10 @@ def recovery( targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, - point_agg=np.median, - uncertainty_agg=median_abs_deviation, + point_agg: Callable = np.median, + uncertainty_agg: Callable = credible_interval, + point_agg_kwargs: dict = None, + uncertainty_agg_kwargs: dict = None, add_corr: bool = True, figsize: Sequence[int] = None, label_fontsize: int = 16, @@ -57,8 +58,17 @@ def recovery( By default, select all keys. variable_names : list or None, optional, default: None The individual parameter names for nice plot titles. Inferred if None - point_agg : function to compute point estimates. Default: median - uncertainty_agg : function to compute uncertainty estimates. Default: MAD + point_agg : callable, optional, default: median + Function to compute point estimates. + uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95% + Function to compute a measure of uncertainty. Can either be the lower and upper + uncertainty bounds provided with the shape (2, num_datasets, num_params) or a + scalar measure of uncertainty (e.g., the median absolute deviation) with shape + (num_datasets, num_params). + point_agg_kwargs : Optional dictionary of further arguments passed to point_agg. + uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg. + For example, to change the coverage probability of credible_interval to 50%, + use uncertainty_agg_kwargs = dict(prob=0.5) add_corr : boolean, default: True Should correlations between estimates and ground truth values be shown? figsize : tuple or None, optional, default : None @@ -106,11 +116,18 @@ def recovery( estimates = plot_data.pop("estimates") targets = plot_data.pop("targets") + point_agg_kwargs = point_agg_kwargs or {} + uncertainty_agg_kwargs = uncertainty_agg_kwargs or {} + # Compute point estimates and uncertainties - point_estimate = point_agg(estimates, axis=1) + point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs) if uncertainty_agg is not None: - u = uncertainty_agg(estimates, axis=1) + u = uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs) + if u.ndim == 3: + # compute lower and upper error + u[0, :, :] = point_estimate - u[0, :, :] + u[1, :, :] = u[1, :, :] - point_estimate for i, ax in enumerate(plot_data["axes"].flat): if i >= plot_data["num_variables"]: @@ -121,7 +138,7 @@ def recovery( _ = ax.errorbar( targets[:, i], point_estimate[:, i], - yerr=u[:, i], + yerr=u[..., i], fmt="o", alpha=0.5, color=color, diff --git a/bayesflow/utils/numpy_utils.py b/bayesflow/utils/numpy_utils.py index a9f3c5502..ba79b42ff 100644 --- a/bayesflow/utils/numpy_utils.py +++ b/bayesflow/utils/numpy_utils.py @@ -1,5 +1,6 @@ import numpy as np from scipy import special +from collections.abc import Sequence def inverse_sigmoid(x: np.ndarray) -> np.ndarray: @@ -42,3 +43,47 @@ def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.nd with np.errstate(over="ignore"): exp_beta_x = np.exp(beta * x) return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta) + + +def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] | int = None, **kwargs) -> np.ndarray: + """ + Compute credible interval from samples using quantiles. + + Parameters + ---------- + x : array_like + Input array of samples from a posterior distribution or bootstrap samples. + prob : float, default 0.95 + Coverage probability of the credible interval (between 0 and 1). + For example, 0.95 gives a 95% credible interval. + axis : Sequence[int] + Axis or axes along which the credible interval is computed. + Default is None (flatten array). + + Returns + ------- + a numpy array of shape (2, ...) with the first dimension indicating the + lower and upper bounds of the credible interval. + + Examples + -------- + >>> import numpy as np + >>> # Simulate posterior samples + >>> samples = np.random.normal(size=(10, 1000, 3)) + + >>> # Different coverage probabilities + >>> credible_interval(samples, prob=0.5, axis=1) # 50% CI + >>> credible_interval(samples, prob=0.99, axis=1) # 99% CI + """ + + # Input validation + if not 0 <= prob <= 1: + raise ValueError(f"prob must be between 0 and 1, got {prob}") + + # Calculate tail probabilities + alpha = 1 - prob + lower_q = alpha / 2 + upper_q = 1 - alpha / 2 + + # Compute quantiles + return np.quantile(x, q=(lower_q, upper_q), axis=axis, **kwargs) diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py index e2f1c09f2..6f449787e 100644 --- a/tests/test_diagnostics/test_diagnostics_plots.py +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -92,9 +92,20 @@ def test_loss(history): assert out.axes[0].title._text == "Loss Trajectory" -def test_recovery(random_estimates, random_targets): +def test_recovery_bounds(random_estimates, random_targets): # basic functionality: automatic variable names - out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4) + from bayesflow.utils.numpy_utils import credible_interval + + out = bf.diagnostics.plots.recovery( + random_estimates, random_targets, markersize=4, uncertainty_agg=credible_interval + ) + assert len(out.axes) == num_variables(random_estimates) + assert out.axes[2].title._text == "sigma" + + +def test_recovery_symmetric(random_estimates, random_targets): + # basic functionality: automatic variable names + out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4, uncertainty_agg=np.std) assert len(out.axes) == num_variables(random_estimates) assert out.axes[2].title._text == "sigma"