From a4432a20561b30923f55f9f300c0def917718f7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Thu, 11 Sep 2025 13:55:29 +0200 Subject: [PATCH 1/6] use credible intervals for uncertainties in recovery plots --- bayesflow/diagnostics/plots/recovery.py | 16 +++++---- bayesflow/utils/numpy_utils.py | 45 +++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index 462f06546..45931ae5f 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -3,9 +3,8 @@ import matplotlib.pyplot as plt import numpy as np -from scipy.stats import median_abs_deviation - from bayesflow.utils import prepare_plot_data, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric +from bayesflow.utils.numpy_utils import credible_interval def recovery( @@ -14,7 +13,8 @@ def recovery( variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, point_agg=np.median, - uncertainty_agg=median_abs_deviation, + uncertainty_agg=credible_interval, + prob=0.95, add_corr: bool = True, figsize: Sequence[int] = None, label_fontsize: int = 16, @@ -58,7 +58,8 @@ def recovery( variable_names : list or None, optional, default: None The individual parameter names for nice plot titles. Inferred if None point_agg : function to compute point estimates. Default: median - uncertainty_agg : function to compute uncertainty estimates. Default: MAD + uncertainty_agg : function to compute uncertainty interval bounds. + Default: credible_interval add_corr : boolean, default: True Should correlations between estimates and ground truth values be shown? 
figsize : tuple or None, optional, default : None @@ -110,7 +111,10 @@ def recovery( point_estimate = point_agg(estimates, axis=1) if uncertainty_agg is not None: - u = uncertainty_agg(estimates, axis=1) + u = uncertainty_agg(estimates, prob=prob, axis=1) + # compute lower and upper error + u[0, :, :] = point_estimate - u[0, :, :] + u[1, :, :] = u[1, :, :] - point_estimate for i, ax in enumerate(plot_data["axes"].flat): if i >= plot_data["num_variables"]: @@ -121,7 +125,7 @@ def recovery( _ = ax.errorbar( targets[:, i], point_estimate[:, i], - yerr=u[:, i], + yerr=u[:, :, i], fmt="o", alpha=0.5, color=color, diff --git a/bayesflow/utils/numpy_utils.py b/bayesflow/utils/numpy_utils.py index a9f3c5502..28ab2f738 100644 --- a/bayesflow/utils/numpy_utils.py +++ b/bayesflow/utils/numpy_utils.py @@ -1,5 +1,6 @@ import numpy as np from scipy import special +from collections.abc import Sequence def inverse_sigmoid(x: np.ndarray) -> np.ndarray: @@ -42,3 +43,47 @@ def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.nd with np.errstate(over="ignore"): exp_beta_x = np.exp(beta * x) return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta) + + +def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] = None, **kwargs) -> np.ndarray: + """ + Compute credible interval from samples using quantiles. + + Parameters + ---------- + x : array_like + Input array of samples from a posterior distribution or bootstrap samples. + prob : float, default 0.95 + Coverage probability of the credible interval (between 0 and 1). + For example, 0.95 gives a 95% credible interval. + axis : Sequence[int] + Axis or axes along which the credible interval is computed. + Default is None (flatten array). + + Returns + ------- + a numpy array of shape (2, ...) with the first dimension indicating the + lower and upper bounds of the credible interval. 
+ + Examples + -------- + >>> import numpy as np + >>> # Simulate posterior samples + >>> samples = np.random.normal(10, 1000, 3) + + >>> # Different coverage probabilities + >>> credible_interval(samples, prob=0.5, axis=1) # 50% CI + >>> credible_interval(samples, prob=0.99, axis=1) # 99% CI + """ + + # Input validation + if not 0 <= prob <= 1: + raise ValueError(f"prob must be between 0 and 1, got {prob}") + + # Calculate tail probabilities + alpha = 1 - prob + lower_q = alpha / 2 + upper_q = 1 - alpha / 2 + + # Compute quantiles + return np.quantile(x, q=(lower_q, upper_q), axis=axis, **kwargs) From 89007cec03c8b57aab1a476c51a96e5d02dcc3d1 Mon Sep 17 00:00:00 2001 From: stefanradev93 Date: Thu, 11 Sep 2025 09:11:50 -0400 Subject: [PATCH 2/6] Add kwargs for agg functions --- bayesflow/diagnostics/plots/recovery.py | 5 ++--- bayesflow/utils/numpy_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index 45931ae5f..bbeac3b38 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -14,7 +14,6 @@ def recovery( variable_names: Sequence[str] = None, point_agg=np.median, uncertainty_agg=credible_interval, - prob=0.95, add_corr: bool = True, figsize: Sequence[int] = None, label_fontsize: int = 16, @@ -108,10 +107,10 @@ def recovery( targets = plot_data.pop("targets") # Compute point estimates and uncertainties - point_estimate = point_agg(estimates, axis=1) + point_estimate = point_agg(estimates, axis=1, **kwargs.get("point_agg_kwargs", {})) if uncertainty_agg is not None: - u = uncertainty_agg(estimates, prob=prob, axis=1) + u = uncertainty_agg(estimates, axis=1, **kwargs.get("uncertainty_agg_kwargs", {})) # compute lower and upper error u[0, :, :] = point_estimate - u[0, :, :] u[1, :, :] = u[1, :, :] - point_estimate diff --git a/bayesflow/utils/numpy_utils.py b/bayesflow/utils/numpy_utils.py index 
28ab2f738..ba79b42ff 100644 --- a/bayesflow/utils/numpy_utils.py +++ b/bayesflow/utils/numpy_utils.py @@ -45,7 +45,7 @@ def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.nd return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta) -def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] = None, **kwargs) -> np.ndarray: +def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] | int = None, **kwargs) -> np.ndarray: """ Compute credible interval from samples using quantiles. @@ -69,7 +69,7 @@ def credible_interval(x: np.ndarray, prob: float = 0.95, axis: Sequence[int] = N -------- >>> import numpy as np >>> # Simulate posterior samples - >>> samples = np.random.normal(10, 1000, 3) + >>> samples = np.random.normal(size=(10, 1000, 3)) >>> # Different coverage probabilities >>> credible_interval(samples, prob=0.5, axis=1) # 50% CI From 0b6caeacfe8da235b487e006deb79836bdde5e18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Fri, 12 Sep 2025 09:32:36 +0200 Subject: [PATCH 3/6] make point_agg_kwargs and uncertainty_agg_kwargs explicit arguments --- bayesflow/diagnostics/plots/recovery.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index bbeac3b38..7a723511e 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -14,6 +14,8 @@ def recovery( variable_names: Sequence[str] = None, point_agg=np.median, uncertainty_agg=credible_interval, + point_agg_kwargs=None, + uncertainty_agg_kwargs=None, add_corr: bool = True, figsize: Sequence[int] = None, label_fontsize: int = 16, @@ -58,7 +60,11 @@ def recovery( The individual parameter names for nice plot titles. Inferred if None point_agg : function to compute point estimates. Default: median uncertainty_agg : function to compute uncertainty interval bounds. 
- Default: credible_interval + Default: credible_interval with coverage probability 95%. + point_agg_kwargs : Optional dictionary of further arguments passed to point_agg. + uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg. + For example, to change the coverage probability of credible_interval to 50%, + use uncertainty_agg_kwargs = dict(prob = 0.5) add_corr : boolean, default: True Should correlations between estimates and ground truth values be shown? figsize : tuple or None, optional, default : None @@ -106,11 +112,17 @@ def recovery( estimates = plot_data.pop("estimates") targets = plot_data.pop("targets") + if point_agg_kwargs is None: + point_agg_kwargs = {} + + if uncertainty_agg_kwargs is None: + uncertainty_agg_kwargs = {} + # Compute point estimates and uncertainties - point_estimate = point_agg(estimates, axis=1, **kwargs.get("point_agg_kwargs", {})) + point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs) if uncertainty_agg is not None: - u = uncertainty_agg(estimates, axis=1, **kwargs.get("uncertainty_agg_kwargs", {})) + u = uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs) # compute lower and upper error u[0, :, :] = point_estimate - u[0, :, :] u[1, :, :] = u[1, :, :] - point_estimate From cc537410970c9832a78dda85496fa5fda596fa2e Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 12 Sep 2025 19:00:00 +0000 Subject: [PATCH 4/6] adapt docs, minor stylistic changes --- bayesflow/diagnostics/plots/recovery.py | 26 ++++++++++++------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index 7a723511e..ad95c83f5 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence, Mapping +from collections.abc import Sequence, Mapping, Callable import matplotlib.pyplot as plt import numpy as np @@ -12,10 
+12,10 @@ def recovery( targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, - point_agg=np.median, - uncertainty_agg=credible_interval, - point_agg_kwargs=None, - uncertainty_agg_kwargs=None, + point_agg: Callable = np.median, + uncertainty_agg: Callable = credible_interval, + point_agg_kwargs: dict = None, + uncertainty_agg_kwargs: dict = None, add_corr: bool = True, figsize: Sequence[int] = None, label_fontsize: int = 16, @@ -58,13 +58,14 @@ def recovery( By default, select all keys. variable_names : list or None, optional, default: None The individual parameter names for nice plot titles. Inferred if None - point_agg : function to compute point estimates. Default: median - uncertainty_agg : function to compute uncertainty interval bounds. - Default: credible_interval with coverage probability 95%. + point_agg : callable, optional, default: median + Function to compute point estimates. + uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95% + Function to compute uncertainty interval bounds. point_agg_kwargs : Optional dictionary of further arguments passed to point_agg. uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg. For example, to change the coverage probability of credible_interval to 50%, - use uncertainty_agg_kwargs = dict(prob = 0.5) + use uncertainty_agg_kwargs = dict(prob=0.5) add_corr : boolean, default: True Should correlations between estimates and ground truth values be shown? 
figsize : tuple or None, optional, default : None @@ -112,11 +113,8 @@ def recovery( estimates = plot_data.pop("estimates") targets = plot_data.pop("targets") - if point_agg_kwargs is None: - point_agg_kwargs = {} - - if uncertainty_agg_kwargs is None: - uncertainty_agg_kwargs = {} + point_agg_kwargs = point_agg_kwargs or {} + uncertainty_agg_kwargs = uncertainty_agg_kwargs or {} # Compute point estimates and uncertainties point_estimate = point_agg(estimates, axis=1, **point_agg_kwargs) From 217ea69228a3eb9b3277fc38e41481ca37147145 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 12 Sep 2025 19:27:04 +0000 Subject: [PATCH 5/6] add support for symmetric uncertainty measures --- bayesflow/diagnostics/plots/recovery.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index ad95c83f5..6f30e1bdd 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -61,7 +61,10 @@ def recovery( point_agg : callable, optional, default: median Function to compute point estimates. uncertainty_agg : callable, optional, default: credible_interval with coverage probability 95% - Function to compute uncertainty interval bounds. + Function to compute a measure of uncertainty. Can either be the lower and upper + uncertainty bounds provided with the shape (2, num_datasets, num_params) or a + scalar measure of uncertainty (e.g., the median absolute deviation) with shape + (num_datasets, num_params). point_agg_kwargs : Optional dictionary of further arguments passed to point_agg. uncertainty_agg_kwargs : Optional dictionary of further arguments passed to uncertainty_agg. 
For example, to change the coverage probability of credible_interval to 50%, @@ -121,9 +124,10 @@ def recovery( if uncertainty_agg is not None: u = uncertainty_agg(estimates, axis=1, **uncertainty_agg_kwargs) - # compute lower and upper error - u[0, :, :] = point_estimate - u[0, :, :] - u[1, :, :] = u[1, :, :] - point_estimate + if u.ndim == 3: + # compute lower and upper error + u[0, :, :] = point_estimate - u[0, :, :] + u[1, :, :] = u[1, :, :] - point_estimate for i, ax in enumerate(plot_data["axes"].flat): if i >= plot_data["num_variables"]: @@ -134,7 +138,7 @@ def recovery( _ = ax.errorbar( targets[:, i], point_estimate[:, i], - yerr=u[:, :, i], + yerr=u[..., i], fmt="o", alpha=0.5, color=color, From d6f57412f06e011c388b9e0118e346d7b8383857 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 12 Sep 2025 19:29:27 +0000 Subject: [PATCH 6/6] add test for symmetric uncertainty measure in recovery --- tests/test_diagnostics/test_diagnostics_plots.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py index e2f1c09f2..6f449787e 100644 --- a/tests/test_diagnostics/test_diagnostics_plots.py +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -92,9 +92,20 @@ def test_loss(history): assert out.axes[0].title._text == "Loss Trajectory" -def test_recovery(random_estimates, random_targets): +def test_recovery_bounds(random_estimates, random_targets): # basic functionality: automatic variable names - out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4) + from bayesflow.utils.numpy_utils import credible_interval + + out = bf.diagnostics.plots.recovery( + random_estimates, random_targets, markersize=4, uncertainty_agg=credible_interval + ) + assert len(out.axes) == num_variables(random_estimates) + assert out.axes[2].title._text == "sigma" + + +def test_recovery_symmetric(random_estimates, random_targets): + 
# basic functionality: automatic variable names + out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4, uncertainty_agg=np.std) assert len(out.axes) == num_variables(random_estimates) assert out.axes[2].title._text == "sigma"