diff --git a/bayesflow/diagnostics/plots/__init__.py b/bayesflow/diagnostics/plots/__init__.py
index fe260aa7e..689e26013 100644
--- a/bayesflow/diagnostics/plots/__init__.py
+++ b/bayesflow/diagnostics/plots/__init__.py
@@ -1,6 +1,7 @@
 from .calibration_ecdf import calibration_ecdf
 from .calibration_ecdf_from_quantiles import calibration_ecdf_from_quantiles
 from .calibration_histogram import calibration_histogram
+from .coverage import coverage
 from .loss import loss
 from .mc_calibration import mc_calibration
 from .mc_confusion_matrix import mc_confusion_matrix
diff --git a/bayesflow/diagnostics/plots/coverage.py b/bayesflow/diagnostics/plots/coverage.py
new file mode 100644
index 000000000..f7ced2de5
--- /dev/null
+++ b/bayesflow/diagnostics/plots/coverage.py
@@ -0,0 +1,183 @@
+from collections.abc import Sequence, Mapping
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots, compute_empirical_coverage
+
+
+def coverage(
+    estimates: Mapping[str, np.ndarray] | np.ndarray,
+    targets: Mapping[str, np.ndarray] | np.ndarray,
+    difference: bool = False,
+    variable_keys: Sequence[str] = None,
+    variable_names: Sequence[str] = None,
+    figsize: Sequence[int] = None,
+    label_fontsize: int = 16,
+    title_fontsize: int = 18,
+    tick_fontsize: int = 12,
+    color: str = "#132a70",
+    num_col: int = None,
+    num_row: int = None,
+) -> plt.Figure:
+    """
+    Creates coverage plots showing empirical coverage of posterior credible intervals.
+
+    The empirical coverage shows the coverage (proportion of true variable values that fall within the interval)
+    of the central posterior credible intervals.
+    A well-calibrated model would have coverage exactly match interval width (i.e. 95%
+    credible interval contains the true value 95% of the time) as shown by the diagonal line.
+
+    The coverage is accompanied by credible intervals for the coverage (gray ribbon).
+    These are computed via the (conjugate) Beta-Binomial model for binomial proportions with a uniform prior.
+    For more details on the Beta-Binomial model, see Chapter 2 of Bayesian Data Analysis (2013, 3rd ed.) by
+    Gelman A., Carlin J., Stern H., Dunson D., Vehtari A., & Rubin D.
+
+    Parameters
+    ----------
+    estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
+        The posterior draws obtained from num_datasets
+    targets : np.ndarray of shape (num_datasets, num_params)
+        The true parameter values used for generating num_datasets
+    difference : bool, optional, default: False
+        If True, plots the difference between empirical coverage and ideal coverage
+        (coverage - width), making deviations from ideal calibration more visible.
+        If False, plots the standard coverage plot.
+    variable_keys : list or None, optional, default: None
+        Select keys from the dictionaries provided in estimates and targets.
+        By default, select all keys.
+    variable_names : list or None, optional, default: None
+        The parameter names for nice plot titles. Inferred if None
+    figsize : tuple or None, optional, default: None
+        The figure size passed to the matplotlib constructor. Inferred if None.
+    label_fontsize : int, optional, default: 16
+        The font size of the y-label and x-label text
+    title_fontsize : int, optional, default: 18
+        The font size of the title text
+    tick_fontsize : int, optional, default: 12
+        The font size of the axis ticklabels
+    color : str, optional, default: '#132a70'
+        The color for the coverage line
+    num_row : int, optional, default: None
+        The number of rows for the subplots. Dynamically determined if None.
+    num_col : int, optional, default: None
+        The number of columns for the subplots. Dynamically determined if None.
+
+    Returns
+    -------
+    f : plt.Figure - the figure instance for optional saving
+
+    Raises
+    ------
+    ShapeError
+        If there is a deviation from the expected shapes of ``estimates`` and ``targets``.
+
+    """
+
+    # Gather plot data and metadata into a dictionary
+    plot_data = prepare_plot_data(
+        estimates=estimates,
+        targets=targets,
+        variable_keys=variable_keys,
+        variable_names=variable_names,
+        num_col=num_col,
+        num_row=num_row,
+        figsize=figsize,
+    )
+
+    estimates = plot_data.pop("estimates")
+    targets = plot_data.pop("targets")
+
+    # Determine widths to compute coverage for
+    num_draws = estimates.shape[1]
+    widths = np.arange(0, num_draws + 2) / (num_draws + 1)
+
+    # Compute empirical coverage with default parameters
+    coverage_data = compute_empirical_coverage(
+        estimates=estimates,
+        targets=targets,
+        widths=widths,
+        prob=0.95,
+        interval_type="central",
+    )
+
+    # Plot coverage for each parameter
+    for i, ax in enumerate(plot_data["axes"].flat):
+        if i >= plot_data["num_variables"]:
+            break
+
+        width_rep = coverage_data["width_represented"][:, i]
+        coverage_est = coverage_data["coverage_estimates"][:, i]
+        coverage_low = coverage_data["coverage_lower"][:, i]
+        coverage_high = coverage_data["coverage_upper"][:, i]
+
+        if difference:
+            # Compute differences for coverage difference plot
+            diff_est = coverage_est - width_rep
+            diff_low = coverage_low - width_rep
+            diff_high = coverage_high - width_rep
+
+            # Plot confidence ribbon
+            ax.fill_between(
+                width_rep,
+                diff_low,
+                diff_high,
+                color="grey",
+                alpha=0.33,
+                label="95% Credible Interval",
+            )
+
+            # Plot ideal coverage difference line (y = 0)
+            ax.axhline(y=0, color="skyblue", linewidth=2.0, label="Ideal Coverage")
+
+            # Plot empirical coverage difference
+            ax.plot(width_rep, diff_est, color=color, alpha=1.0, label="Coverage Difference")
+
+            # Set axis limits
+            ax.set_xlim(0, 1)
+
+            # Add legend to first subplot
+            if i == 0:
+                ax.legend(fontsize=tick_fontsize, loc="upper right")
+        else:
+            # Plot confidence ribbon
+            ax.fill_between(
+                width_rep,
+                coverage_low,
+                coverage_high,
+                color="grey",
+                alpha=0.33,
+                label="95% Credible Interval",
+            )
+
+            # Plot ideal coverage line (y = x)
+            ax.plot([0, 1], [0, 1], color="skyblue", linewidth=2.0, label="Ideal Coverage")
+
+            # Plot empirical coverage
+            ax.plot(width_rep, coverage_est, color=color, alpha=1.0, label="Empirical Coverage")
+
+            # Set axis limits
+            ax.set_xlim(0, 1)
+            ax.set_ylim(0, 1)
+
+            # Add legend to first subplot
+            if i == 0:
+                ax.legend(fontsize=tick_fontsize, loc="upper left")
+
+    prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
+
+    # Add labels, titles, and set font sizes
+    ylabel = "Observed coverage difference" if difference else "Observed coverage"
+    add_titles_and_labels(
+        axes=plot_data["axes"],
+        num_row=plot_data["num_row"],
+        num_col=plot_data["num_col"],
+        title=plot_data["variable_names"],
+        xlabel="Central interval width",
+        ylabel=ylabel,
+        title_fontsize=title_fontsize,
+        label_fontsize=label_fontsize,
+    )
+
+    plot_data["fig"].tight_layout()
+    return plot_data["fig"]
diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py
index 4e9bdd8d5..a8d28a50a 100644
--- a/bayesflow/utils/__init__.py
+++ b/bayesflow/utils/__init__.py
@@ -71,6 +71,7 @@
     prettify_subplots,
     make_quadratic,
     add_metric,
+    compute_empirical_coverage,
 )
 
 from .serialization import serialize_value_or_type, deserialize_value_or_type
diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py
index 398d2d970..a181fe940 100644
--- a/bayesflow/utils/plot_utils.py
+++ b/bayesflow/utils/plot_utils.py
@@ -1,6 +1,8 @@
 from typing import Sequence, Any, Mapping
 
 import numpy as np
+from scipy.stats import beta
+
 import matplotlib.pyplot as plt
 import seaborn as sns
 
@@ -93,6 +95,112 @@ def prepare_plot_data(
     return plot_data
 
 
+def compute_empirical_coverage(
+    estimates: np.ndarray,
+    targets: np.ndarray,
+    widths: np.ndarray,
+    prob: float = 0.95,
+    interval_type: str = "central",
+) -> dict:
+    """
+    Compute empirical coverage statistics for given interval widths.
+
+    For each width, an interval is formed from the order statistics of the posterior draws and the
+    proportion of datasets whose target lies inside it is recorded, together with a Beta-Binomial
+    credible interval for that proportion.
+
+    Parameters
+    ----------
+    estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
+        The posterior draws obtained from num_datasets
+    targets : np.ndarray of shape (num_datasets, num_params)
+        The true parameter values used for generating num_datasets
+    widths : np.ndarray
+        Array of interval widths to compute coverage for (values between 0 and 1)
+    prob : float, optional, default: 0.95
+        Probability mass of the credible intervals reported for the coverage estimates
+    interval_type : str, optional, default: "central"
+        Type of credible interval. Either "central" or "leftmost"
+
+    Returns
+    -------
+    dict
+        Dictionary with the arrays ``coverage_estimates``, ``coverage_lower``, ``coverage_upper`` and
+        ``width_represented``, each of shape (num_widths, num_params), and the input ``widths``.
+
+    Raises
+    ------
+    ValueError
+        If ``interval_type`` is neither "central" nor "leftmost".
+    """
+    if interval_type not in ("central", "leftmost"):
+        raise ValueError("interval_type must be 'central' or 'leftmost'")
+
+    num_datasets, num_draws, num_params = estimates.shape
+    num_widths = len(widths)
+
+    # Sort the posterior draws once; the order statistics are the same for every width
+    sorted_samples = np.sort(estimates, axis=1)
+
+    # Initialize output arrays
+    coverage_estimates = np.zeros((num_widths, num_params))
+    coverage_lower = np.zeros((num_widths, num_params))
+    coverage_upper = np.zeros((num_widths, num_params))
+    width_represented = np.zeros((num_widths, num_params))
+
+    for w_idx, width in enumerate(widths):
+        # Number of ranks to cover for this width
+        n_ranks_covered = round((num_draws + 1) * width)
+
+        # Central intervals are centered around the median, leftmost intervals start at the minimum
+        low_rank = round(num_draws / 2 - n_ranks_covered / 2) if interval_type == "central" else 0
+        high_rank = low_rank + n_ranks_covered - 1
+
+        # Ensure ranks are within valid bounds
+        low_rank = max(0, low_rank)
+        high_rank = min(num_draws - 1, high_rank)
+
+        # Actual width represented by these ranks
+        # NOTE(review): the comparison below is on draw values, so a target seems to be covered only if its
+        # rank lies in (low_rank, high_rank], i.e. one rank fewer than counted here -- confirm this is intended
+        actual_width = (high_rank - low_rank + 1) / (num_draws + 1)
+
+        if high_rank < low_rank:
+            # Empty interval (e.g. width 0): nothing is covered. Without this guard, high_rank = -1
+            # would wrap around to the largest draw and report almost full coverage.
+            is_covered = np.zeros((num_datasets, num_params), dtype=bool)
+        else:
+            # Check if the true values fall within the credible interval (all parameters at once)
+            is_covered = (targets >= sorted_samples[:, low_rank]) & (targets <= sorted_samples[:, high_rank])
+
+        # Compute the number of covered targets per parameter
+        num_covered = np.sum(is_covered, axis=0)
+
+        # Special handling for boundary cases
+        if actual_width == 0 or actual_width == 1:
+            # No variability possible
+            ci_low = ci_high = actual_width
+        else:
+            # Credible interval for a binomial proportion under a uniform prior (Beta posterior)
+            alpha_post = num_covered + 1
+            beta_post = num_datasets - num_covered + 1
+            ci_low = beta.ppf((1 - prob) / 2, alpha_post, beta_post)
+            ci_high = beta.ppf((1 + prob) / 2, alpha_post, beta_post)
+
+        coverage_estimates[w_idx] = num_covered / num_datasets
+        coverage_lower[w_idx] = ci_low
+        coverage_upper[w_idx] = ci_high
+        width_represented[w_idx] = actual_width
+
+    return {
+        "coverage_estimates": coverage_estimates,
+        "coverage_lower": coverage_lower,
+        "coverage_upper": coverage_upper,
+        "width_represented": width_represented,
+        "widths": widths,
+    }
+
+
 def set_layout(num_total: int, num_row: int = None, num_col: int = None, stacked: bool = False):
     """
     Determine the number of rows and columns in diagnostics visualizations.
diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py
index 6f449787e..fd01f31dc 100644
--- a/tests/test_diagnostics/test_diagnostics_plots.py
+++ b/tests/test_diagnostics/test_diagnostics_plots.py
@@ -281,3 +281,21 @@ def test_mc_confusion_matrix(pred_models, true_models, model_names):
     assert out.axes[0].get_ylabel() == "True model"
     assert out.axes[0].get_xlabel() == "Predicted model"
     assert out.axes[0].get_title() == "Confusion Matrix"
+
+
+def test_coverage(random_estimates, random_targets):
+    # basic functionality: automatic variable names
+    out = bf.diagnostics.plots.coverage(random_estimates, random_targets)
+    assert len(out.axes) == num_variables(random_estimates)
+    assert out.axes[1].title._text == "beta_1"
+    assert out.axes[0].get_xlabel() == "Central interval width"
+    assert out.axes[0].get_ylabel() == "Observed coverage"
+
+
+def test_coverage_diff(random_estimates, random_targets):
+    # basic functionality: automatic variable names
+    out = bf.diagnostics.plots.coverage(random_estimates, random_targets, difference=True)
+    assert len(out.axes) == num_variables(random_estimates)
+    assert out.axes[1].title._text == "beta_1"
+    assert out.axes[0].get_xlabel() == "Central interval width"
+    assert out.axes[0].get_ylabel() == "Observed coverage difference"