Skip to content

Commit 8ef46f6

Browse files
Add empirical coverage diagnostic plots (#579)
* add empirical coverage diagnostic plots * merging the plotting functions via difference arg and add beta binomial reference * adjust tests and diagnostics init
1 parent f1cdbbd commit 8ef46f6

File tree

5 files changed

+305
-0
lines changed

5 files changed

+305
-0
lines changed

bayesflow/diagnostics/plots/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .calibration_ecdf import calibration_ecdf
22
from .calibration_ecdf_from_quantiles import calibration_ecdf_from_quantiles
33
from .calibration_histogram import calibration_histogram
4+
from .coverage import coverage
45
from .loss import loss
56
from .mc_calibration import mc_calibration
67
from .mc_confusion_matrix import mc_confusion_matrix
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from collections.abc import Sequence, Mapping
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots, compute_empirical_coverage
7+
8+
9+
def coverage(
10+
estimates: Mapping[str, np.ndarray] | np.ndarray,
11+
targets: Mapping[str, np.ndarray] | np.ndarray,
12+
difference: bool = False,
13+
variable_keys: Sequence[str] = None,
14+
variable_names: Sequence[str] = None,
15+
figsize: Sequence[int] = None,
16+
label_fontsize: int = 16,
17+
title_fontsize: int = 18,
18+
tick_fontsize: int = 12,
19+
color: str = "#132a70",
20+
num_col: int = None,
21+
num_row: int = None,
22+
) -> plt.Figure:
23+
"""
24+
Creates coverage plots showing empirical coverage of posterior credible intervals.
25+
26+
The empirical coverage shows the coverage (proportion of true variable values that fall within the interval)
27+
of the central posterior credible intervals.
28+
A well-calibrated model would have coverage exactly match interval width (i.e. 95%
29+
credible interval contains the true value 95% of the time) as shown by the diagonal line.
30+
31+
The coverage is accompanied by credible intervals for the coverage (gray ribbon).
32+
These are computed via the (conjugate) Beta-Binomial model for binomial proportions with a uniform prior.
33+
For more details on the Beta-Binomial model, see Chapter 2 of Bayesian Data Analysis (2013, 3rd ed.) by
34+
Gelman A., Carlin J., Stern H., Dunson D., Vehtari A., & Rubin D.
35+
36+
Parameters
37+
----------
38+
estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
39+
The posterior draws obtained from num_datasets
40+
targets : np.ndarray of shape (num_datasets, num_params)
41+
The true parameter values used for generating num_datasets
42+
difference : bool, optional, default: False
43+
If True, plots the difference between empirical coverage and ideal coverage
44+
(coverage - width), making deviations from ideal calibration more visible.
45+
If False, plots the standard coverage plot.
46+
variable_keys : list or None, optional, default: None
47+
Select keys from the dictionaries provided in estimates and targets.
48+
By default, select all keys.
49+
variable_names : list or None, optional, default: None
50+
The parameter names for nice plot titles. Inferred if None
51+
figsize : tuple or None, optional, default: None
52+
The figure size passed to the matplotlib constructor. Inferred if None.
53+
label_fontsize : int, optional, default: 16
54+
The font size of the y-label and x-label text
55+
title_fontsize : int, optional, default: 18
56+
The font size of the title text
57+
tick_fontsize : int, optional, default: 12
58+
The font size of the axis ticklabels
59+
color : str, optional, default: '#132a70'
60+
The color for the coverage line
61+
num_row : int, optional, default: None
62+
The number of rows for the subplots. Dynamically determined if None.
63+
num_col : int, optional, default: None
64+
The number of columns for the subplots. Dynamically determined if None.
65+
66+
Returns
67+
-------
68+
f : plt.Figure - the figure instance for optional saving
69+
70+
Raises
71+
------
72+
ShapeError
73+
If there is a deviation from the expected shapes of ``estimates`` and ``targets``.
74+
75+
"""
76+
77+
# Gather plot data and metadata into a dictionary
78+
plot_data = prepare_plot_data(
79+
estimates=estimates,
80+
targets=targets,
81+
variable_keys=variable_keys,
82+
variable_names=variable_names,
83+
num_col=num_col,
84+
num_row=num_row,
85+
figsize=figsize,
86+
)
87+
88+
estimates = plot_data.pop("estimates")
89+
targets = plot_data.pop("targets")
90+
91+
# Determine widths to compute coverage for
92+
num_draws = estimates.shape[1]
93+
widths = np.arange(0, num_draws + 2) / (num_draws + 1)
94+
95+
# Compute empirical coverage with default parameters
96+
coverage_data = compute_empirical_coverage(
97+
estimates=estimates,
98+
targets=targets,
99+
widths=widths,
100+
prob=0.95,
101+
interval_type="central",
102+
)
103+
104+
# Plot coverage for each parameter
105+
for i, ax in enumerate(plot_data["axes"].flat):
106+
if i >= plot_data["num_variables"]:
107+
break
108+
109+
width_rep = coverage_data["width_represented"][:, i]
110+
coverage_est = coverage_data["coverage_estimates"][:, i]
111+
coverage_low = coverage_data["coverage_lower"][:, i]
112+
coverage_high = coverage_data["coverage_upper"][:, i]
113+
114+
if difference:
115+
# Compute differences for coverage difference plot
116+
diff_est = coverage_est - width_rep
117+
diff_low = coverage_low - width_rep
118+
diff_high = coverage_high - width_rep
119+
120+
# Plot confidence ribbon
121+
ax.fill_between(
122+
width_rep,
123+
diff_low,
124+
diff_high,
125+
color="grey",
126+
alpha=0.33,
127+
label="95% Credible Interval",
128+
)
129+
130+
# Plot ideal coverage difference line (y = 0)
131+
ax.axhline(y=0, color="skyblue", linewidth=2.0, label="Ideal Coverage")
132+
133+
# Plot empirical coverage difference
134+
ax.plot(width_rep, diff_est, color=color, alpha=1.0, label="Coverage Difference")
135+
136+
# Set axis limits
137+
ax.set_xlim(0, 1)
138+
139+
# Add legend to first subplot
140+
if i == 0:
141+
ax.legend(fontsize=tick_fontsize, loc="upper right")
142+
else:
143+
# Plot confidence ribbon
144+
ax.fill_between(
145+
width_rep,
146+
coverage_low,
147+
coverage_high,
148+
color="grey",
149+
alpha=0.33,
150+
label="95% Credible Interval",
151+
)
152+
153+
# Plot ideal coverage line (y = x)
154+
ax.plot([0, 1], [0, 1], color="skyblue", linewidth=2.0, label="Ideal Coverage")
155+
156+
# Plot empirical coverage
157+
ax.plot(width_rep, coverage_est, color=color, alpha=1.0, label="Empirical Coverage")
158+
159+
# Set axis limits
160+
ax.set_xlim(0, 1)
161+
ax.set_ylim(0, 1)
162+
163+
# Add legend to first subplot
164+
if i == 0:
165+
ax.legend(fontsize=tick_fontsize, loc="upper left")
166+
167+
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
168+
169+
# Add labels, titles, and set font sizes
170+
ylabel = "Observed coverage difference" if difference else "Observed coverage"
171+
add_titles_and_labels(
172+
axes=plot_data["axes"],
173+
num_row=plot_data["num_row"],
174+
num_col=plot_data["num_col"],
175+
title=plot_data["variable_names"],
176+
xlabel="Central interval width",
177+
ylabel=ylabel,
178+
title_fontsize=title_fontsize,
179+
label_fontsize=label_fontsize,
180+
)
181+
182+
plot_data["fig"].tight_layout()
183+
return plot_data["fig"]

bayesflow/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
prettify_subplots,
7272
make_quadratic,
7373
add_metric,
74+
compute_empirical_coverage,
7475
)
7576
from .serialization import serialize_value_or_type, deserialize_value_or_type
7677

bayesflow/utils/plot_utils.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Sequence, Any, Mapping
22

33
import numpy as np
4+
from scipy.stats import beta
5+
46
import matplotlib.pyplot as plt
57
import seaborn as sns
68

@@ -93,6 +95,106 @@ def prepare_plot_data(
9395
return plot_data
9496

9597

98+
def compute_empirical_coverage(
99+
estimates: np.ndarray,
100+
targets: np.ndarray,
101+
widths: np.ndarray,
102+
prob: float = 0.95,
103+
interval_type: str = "central",
104+
) -> dict:
105+
"""
106+
Compute empirical coverage statistics for given interval widths.
107+
108+
Parameters
109+
----------
110+
estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
111+
The posterior draws obtained from num_datasets
112+
targets : np.ndarray of shape (num_datasets, num_params)
113+
The true parameter values used for generating num_datasets
114+
widths : np.ndarray
115+
Array of interval widths to compute coverage for (values between 0 and 1)
116+
prob : float, optional, default: 0.95
117+
Confidence level for coverage confidence intervals
118+
interval_type : str, optional, default: "central"
119+
Type of credible interval. Either "central" or "leftmost"
120+
121+
Returns
122+
-------
123+
dict
124+
Dictionary containing coverage statistics for each width and parameter
125+
"""
126+
num_datasets, num_draws, num_params = estimates.shape
127+
num_widths = len(widths)
128+
129+
# Initialize output arrays
130+
coverage_estimates = np.zeros((num_widths, num_params))
131+
coverage_lower = np.zeros((num_widths, num_params))
132+
coverage_upper = np.zeros((num_widths, num_params))
133+
width_represented = np.zeros((num_widths, num_params))
134+
135+
for w_idx, width in enumerate(widths):
136+
# Number of ranks to cover for this width
137+
n_ranks_covered = round((num_draws + 1) * width)
138+
139+
if interval_type == "central":
140+
# Central interval: center around median
141+
low_rank = round(num_draws / 2 - n_ranks_covered / 2)
142+
high_rank = low_rank + n_ranks_covered - 1
143+
elif interval_type == "leftmost":
144+
# Leftmost interval: start from minimum
145+
low_rank = 0
146+
high_rank = n_ranks_covered - 1
147+
else:
148+
raise ValueError("interval_type must be 'central' or 'leftmost'")
149+
150+
# Ensure ranks are within valid bounds
151+
low_rank = max(0, low_rank)
152+
high_rank = min(num_draws - 1, high_rank)
153+
154+
# Actual width represented by these ranks
155+
actual_width = (high_rank - low_rank + 1) / (num_draws + 1)
156+
157+
for p_idx in range(num_params):
158+
# Sort posterior samples for each dataset and parameter
159+
sorted_samples = np.sort(estimates[:, :, p_idx], axis=1)
160+
161+
# Check if true value falls within credible interval
162+
is_covered = (targets[:, p_idx] >= sorted_samples[:, low_rank]) & (
163+
targets[:, p_idx] <= sorted_samples[:, high_rank]
164+
)
165+
166+
# Compute coverage estimate
167+
num_covered = np.sum(is_covered)
168+
coverage_est = num_covered / num_datasets
169+
170+
# Compute confidence intervals using beta distribution
171+
# Using Bayesian credible interval for binomial proportion
172+
alpha_post = num_covered + 1
173+
beta_post = num_datasets - num_covered + 1
174+
175+
# Special handling for boundary cases
176+
if actual_width == 0 or actual_width == 1:
177+
# No variability possible
178+
ci_low = actual_width
179+
ci_high = actual_width
180+
else:
181+
ci_low = beta.ppf((1 - prob) / 2, alpha_post, beta_post)
182+
ci_high = beta.ppf((1 + prob) / 2, alpha_post, beta_post)
183+
184+
coverage_estimates[w_idx, p_idx] = coverage_est
185+
coverage_lower[w_idx, p_idx] = ci_low
186+
coverage_upper[w_idx, p_idx] = ci_high
187+
width_represented[w_idx, p_idx] = actual_width
188+
189+
return {
190+
"coverage_estimates": coverage_estimates,
191+
"coverage_lower": coverage_lower,
192+
"coverage_upper": coverage_upper,
193+
"width_represented": width_represented,
194+
"widths": widths,
195+
}
196+
197+
96198
def set_layout(num_total: int, num_row: int = None, num_col: int = None, stacked: bool = False):
97199
"""
98200
Determine the number of rows and columns in diagnostics visualizations.

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,21 @@ def test_mc_confusion_matrix(pred_models, true_models, model_names):
281281
assert out.axes[0].get_ylabel() == "True model"
282282
assert out.axes[0].get_xlabel() == "Predicted model"
283283
assert out.axes[0].get_title() == "Confusion Matrix"
284+
285+
286+
def test_coverage(random_estimates, random_targets):
287+
# basic functionality: automatic variable names
288+
out = bf.diagnostics.plots.coverage(random_estimates, random_targets)
289+
assert len(out.axes) == num_variables(random_estimates)
290+
assert out.axes[1].title._text == "beta_1"
291+
assert out.axes[0].get_xlabel() == "Central interval width"
292+
assert out.axes[0].get_ylabel() == "Observed coverage"
293+
294+
295+
def test_coverage_diff(random_estimates, random_targets):
296+
# basic functionality: automatic variable names
297+
out = bf.diagnostics.plots.coverage(random_estimates, random_targets, difference=True)
298+
assert len(out.axes) == num_variables(random_estimates)
299+
assert out.axes[1].title._text == "beta_1"
300+
assert out.axes[0].get_xlabel() == "Central interval width"
301+
assert out.axes[0].get_ylabel() == "Observed coverage difference"

0 commit comments

Comments
 (0)