Merged
1 change: 1 addition & 0 deletions bayesflow/diagnostics/plots/__init__.py
@@ -1,6 +1,7 @@
from .calibration_ecdf import calibration_ecdf
from .calibration_ecdf_from_quantiles import calibration_ecdf_from_quantiles
from .calibration_histogram import calibration_histogram
from .coverage import coverage
from .loss import loss
from .mc_calibration import mc_calibration
from .mc_confusion_matrix import mc_confusion_matrix
183 changes: 183 additions & 0 deletions bayesflow/diagnostics/plots/coverage.py
@@ -0,0 +1,183 @@
from collections.abc import Sequence, Mapping

import matplotlib.pyplot as plt
import numpy as np

from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots, compute_empirical_coverage


def coverage(
estimates: Mapping[str, np.ndarray] | np.ndarray,
targets: Mapping[str, np.ndarray] | np.ndarray,
difference: bool = False,
variable_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
figsize: Sequence[int] = None,
label_fontsize: int = 16,
title_fontsize: int = 18,
tick_fontsize: int = 12,
color: str = "#132a70",
num_col: int = None,
num_row: int = None,
) -> plt.Figure:
"""
Creates coverage plots showing empirical coverage of posterior credible intervals.

The empirical coverage is the proportion of true variable values that fall within the
central posterior credible interval of a given width.
For a well-calibrated model, the empirical coverage matches the interval width (e.g., the 95%
credible interval contains the true value 95% of the time), as shown by the diagonal line.

The empirical coverage estimates are accompanied by credible intervals (gray ribbon).
These are computed via the (conjugate) Beta-Binomial model for binomial proportions with a uniform prior.
For more details on the Beta-Binomial model, see Chapter 2 of Bayesian Data Analysis (2013, 3rd ed.) by
Gelman A., Carlin J., Stern H., Dunson D., Vehtari A., & Rubin D.

Parameters
----------
estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params) or dict thereof
The posterior draws obtained from num_datasets simulated datasets
targets : np.ndarray of shape (num_datasets, num_params) or dict thereof
The true parameter values used for generating the num_datasets simulated datasets
difference : bool, optional, default: False
If True, plots the difference between empirical coverage and ideal coverage
(coverage - width), making deviations from ideal calibration more visible.
If False, plots the standard coverage plot.
variable_keys : list or None, optional, default: None
Select keys from the dictionaries provided in estimates and targets.
By default, select all keys.
variable_names : list or None, optional, default: None
The parameter names for nice plot titles. Inferred if None
figsize : tuple or None, optional, default: None
The figure size passed to the matplotlib constructor. Inferred if None.
label_fontsize : int, optional, default: 16
The font size of the y-label and x-label text
title_fontsize : int, optional, default: 18
The font size of the title text
tick_fontsize : int, optional, default: 12
The font size of the axis ticklabels
color : str, optional, default: '#132a70'
The color for the coverage line
num_row : int, optional, default: None
The number of rows for the subplots. Dynamically determined if None.
num_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.

Returns
-------
f : plt.Figure - the figure instance for optional saving

Raises
------
ShapeError
If there is a deviation from the expected shapes of ``estimates`` and ``targets``.
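
Examples
--------
A minimal, shape-level sketch (the random arrays below are stand-ins for real
posterior draws and true parameter values):

>>> import numpy as np
>>> import bayesflow as bf
>>> estimates = np.random.normal(size=(100, 500, 3))  # 100 datasets, 500 draws, 3 params
>>> targets = np.random.normal(size=(100, 3))
>>> fig = bf.diagnostics.plots.coverage(estimates, targets)
>>> fig = bf.diagnostics.plots.coverage(estimates, targets, difference=True)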

"""

# Gather plot data and metadata into a dictionary
plot_data = prepare_plot_data(
estimates=estimates,
targets=targets,
variable_keys=variable_keys,
variable_names=variable_names,
num_col=num_col,
num_row=num_row,
figsize=figsize,
)

estimates = plot_data.pop("estimates")
targets = plot_data.pop("targets")

# Determine widths to compute coverage for
num_draws = estimates.shape[1]
widths = np.arange(0, num_draws + 2) / (num_draws + 1)
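# e.g., num_draws = 3 yields widths [0.0, 0.25, 0.5, 0.75, 1.0], i.e., all
# coverages representable with the num_draws + 1 available rank bins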

# Compute empirical coverage with default parameters
coverage_data = compute_empirical_coverage(
estimates=estimates,
targets=targets,
widths=widths,
prob=0.95,
interval_type="central",
)
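# Each returned array has shape (num_widths, num_variables); column i below
# corresponds to the i-th flattened variable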

# Plot coverage for each parameter
for i, ax in enumerate(plot_data["axes"].flat):
if i >= plot_data["num_variables"]:
break

width_rep = coverage_data["width_represented"][:, i]
coverage_est = coverage_data["coverage_estimates"][:, i]
coverage_low = coverage_data["coverage_lower"][:, i]
coverage_high = coverage_data["coverage_upper"][:, i]

if difference:
# Compute differences for coverage difference plot
diff_est = coverage_est - width_rep
diff_low = coverage_low - width_rep
diff_high = coverage_high - width_rep

# Plot credible-interval ribbon
ax.fill_between(
width_rep,
diff_low,
diff_high,
color="grey",
alpha=0.33,
label="95% Credible Interval",
)

# Plot ideal coverage difference line (y = 0)
ax.axhline(y=0, color="skyblue", linewidth=2.0, label="Ideal Coverage")

# Plot empirical coverage difference
ax.plot(width_rep, diff_est, color=color, alpha=1.0, label="Coverage Difference")

# Set axis limits
ax.set_xlim(0, 1)

# Add legend to first subplot
if i == 0:
ax.legend(fontsize=tick_fontsize, loc="upper right")
else:
# Plot credible-interval ribbon
ax.fill_between(
width_rep,
coverage_low,
coverage_high,
color="grey",
alpha=0.33,
label="95% Credible Interval",
)

# Plot ideal coverage line (y = x)
ax.plot([0, 1], [0, 1], color="skyblue", linewidth=2.0, label="Ideal Coverage")

# Plot empirical coverage
ax.plot(width_rep, coverage_est, color=color, alpha=1.0, label="Empirical Coverage")

# Set axis limits
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# Add legend to first subplot
if i == 0:
ax.legend(fontsize=tick_fontsize, loc="upper left")

prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

# Add labels, titles, and set font sizes
ylabel = "Observed coverage difference" if difference else "Observed coverage"
add_titles_and_labels(
axes=plot_data["axes"],
num_row=plot_data["num_row"],
num_col=plot_data["num_col"],
title=plot_data["variable_names"],
xlabel="Central interval width",
ylabel=ylabel,
title_fontsize=title_fontsize,
label_fontsize=label_fontsize,
)

plot_data["fig"].tight_layout()
return plot_data["fig"]
1 change: 1 addition & 0 deletions bayesflow/utils/__init__.py
@@ -71,6 +71,7 @@
prettify_subplots,
make_quadratic,
add_metric,
compute_empirical_coverage,
)
from .serialization import serialize_value_or_type, deserialize_value_or_type

102 changes: 102 additions & 0 deletions bayesflow/utils/plot_utils.py
@@ -1,6 +1,8 @@
from typing import Sequence, Any, Mapping

import numpy as np
from scipy.stats import beta

import matplotlib.pyplot as plt
import seaborn as sns

@@ -93,6 +95,106 @@ def prepare_plot_data(
return plot_data


def compute_empirical_coverage(
estimates: np.ndarray,
targets: np.ndarray,
widths: np.ndarray,
prob: float = 0.95,
interval_type: str = "central",
) -> dict:
"""
Compute empirical coverage statistics for given interval widths.

Parameters
----------
estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
The posterior draws obtained from num_datasets
targets : np.ndarray of shape (num_datasets, num_params)
The true parameter values used for generating num_datasets
widths : np.ndarray
Array of interval widths to compute coverage for (values between 0 and 1)
prob : float, optional, default: 0.95
Probability mass of the (Beta-Binomial) credible intervals for the coverage estimates
interval_type : str, optional, default: "central"
Type of credible interval. Either "central" or "leftmost"

Returns
-------
dict
Dictionary with keys "coverage_estimates", "coverage_lower", "coverage_upper", and
"width_represented" (each of shape (num_widths, num_params)), plus the input "widths"
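
Examples
--------
A shape-level sketch with random stand-in arrays:

>>> import numpy as np
>>> from bayesflow.utils import compute_empirical_coverage
>>> estimates = np.random.normal(size=(50, 100, 2))
>>> targets = np.random.normal(size=(50, 2))
>>> out = compute_empirical_coverage(estimates, targets, widths=np.linspace(0, 1, 21))
>>> out["coverage_estimates"].shape
(21, 2)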
"""
num_datasets, num_draws, num_params = estimates.shape
num_widths = len(widths)

# Initialize output arrays
coverage_estimates = np.zeros((num_widths, num_params))
coverage_lower = np.zeros((num_widths, num_params))
coverage_upper = np.zeros((num_widths, num_params))
width_represented = np.zeros((num_widths, num_params))

for w_idx, width in enumerate(widths):
# Number of ranks to cover for this width
n_ranks_covered = round((num_draws + 1) * width)

if interval_type == "central":
# Central interval: center around median
low_rank = round(num_draws / 2 - n_ranks_covered / 2)
high_rank = low_rank + n_ranks_covered - 1
elif interval_type == "leftmost":
# Leftmost interval: start from minimum
low_rank = 0
high_rank = n_ranks_covered - 1
else:
raise ValueError("interval_type must be 'central' or 'leftmost'")

# Ensure ranks are within valid bounds
low_rank = max(0, low_rank)
high_rank = min(num_draws - 1, high_rank)

# Actual width represented by these ranks
actual_width = (high_rank - low_rank + 1) / (num_draws + 1)
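# Worked example (central): num_draws = 100, width = 0.5 gives
# n_ranks_covered = 50, ranks 25..74, actual_width = 50 / 101 ≈ 0.495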

for p_idx in range(num_params):
# Sort posterior samples for each dataset and parameter
sorted_samples = np.sort(estimates[:, :, p_idx], axis=1)

# Check if true value falls within credible interval
is_covered = (targets[:, p_idx] >= sorted_samples[:, low_rank]) & (
targets[:, p_idx] <= sorted_samples[:, high_rank]
)

# Compute coverage estimate
num_covered = np.sum(is_covered)
coverage_est = num_covered / num_datasets

# Compute credible intervals for the coverage via quantiles of the
# Beta posterior for a binomial proportion
alpha_post = num_covered + 1
beta_post = num_datasets - num_covered + 1
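# i.e., with a uniform Beta(1, 1) prior, the coverage proportion has
# posterior Beta(num_covered + 1, num_datasets - num_covered + 1)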

# Special handling for boundary cases
if actual_width == 0 or actual_width == 1:
# No variability possible
ci_low = actual_width
ci_high = actual_width
else:
ci_low = beta.ppf((1 - prob) / 2, alpha_post, beta_post)
ci_high = beta.ppf((1 + prob) / 2, alpha_post, beta_post)

coverage_estimates[w_idx, p_idx] = coverage_est
coverage_lower[w_idx, p_idx] = ci_low
coverage_upper[w_idx, p_idx] = ci_high
width_represented[w_idx, p_idx] = actual_width

return {
"coverage_estimates": coverage_estimates,
"coverage_lower": coverage_lower,
"coverage_upper": coverage_upper,
"width_represented": width_represented,
"widths": widths,
}


def set_layout(num_total: int, num_row: int = None, num_col: int = None, stacked: bool = False):
"""
Determine the number of rows and columns in diagnostics visualizations.
18 changes: 18 additions & 0 deletions tests/test_diagnostics/test_diagnostics_plots.py
@@ -281,3 +281,21 @@ def test_mc_confusion_matrix(pred_models, true_models, model_names):
assert out.axes[0].get_ylabel() == "True model"
assert out.axes[0].get_xlabel() == "Predicted model"
assert out.axes[0].get_title() == "Confusion Matrix"


def test_coverage(random_estimates, random_targets):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.coverage(random_estimates, random_targets)
assert len(out.axes) == num_variables(random_estimates)
assert out.axes[1].title._text == "beta_1"
assert out.axes[0].get_xlabel() == "Central interval width"
assert out.axes[0].get_ylabel() == "Observed coverage"


def test_coverage_diff(random_estimates, random_targets):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.coverage(random_estimates, random_targets, difference=True)
assert len(out.axes) == num_variables(random_estimates)
assert out.axes[1].title._text == "beta_1"
assert out.axes[0].get_xlabel() == "Central interval width"
assert out.axes[0].get_ylabel() == "Observed coverage difference"