Skip to content

Commit 7430d8c

Browse files
committed
add empirical coverage diagnostic plots
1 parent 08ed995 commit 7430d8c

File tree

5 files changed

+409
-0
lines changed

5 files changed

+409
-0
lines changed

bayesflow/diagnostics/plots/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .calibration_ecdf import calibration_ecdf
22
from .calibration_ecdf_from_quantiles import calibration_ecdf_from_quantiles
33
from .calibration_histogram import calibration_histogram
4+
from .coverage import coverage, coverage_diff
45
from .loss import loss
56
from .mc_calibration import mc_calibration
67
from .mc_confusion_matrix import mc_confusion_matrix
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
from collections.abc import Sequence, Mapping
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots, compute_empirical_coverage
7+
8+
9+
def coverage(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    figsize: Sequence[int] = None,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    tick_fontsize: int = 12,
    color: str = "#132a70",
    num_col: int = None,
    num_row: int = None,
) -> plt.Figure:
    """
    Plot the empirical coverage of central posterior credible intervals, one subplot per variable.

    For every interval width on a grid from 0 to 1, the plot shows the proportion of true
    variable values that fall inside the central posterior interval of that width.
    For a well-calibrated model the coverage matches the interval width (e.g. a 95% credible
    interval contains the true value 95% of the time), which is drawn as the diagonal line.

    A gray ribbon shows the 95% credible interval for the coverage itself. According to the
    original documentation of this function, it is computed via the (conjugate) Beta-Binomial
    model for binomial proportions with a uniform prior; the computation itself happens in
    ``compute_empirical_coverage``.

    Parameters
    ----------
    estimates : dict or np.ndarray of shape (num_datasets, num_post_draws, num_params)
        The posterior draws obtained from num_datasets
    targets : dict or np.ndarray of shape (num_datasets, num_params)
        The true parameter values used for generating num_datasets
    variable_keys : list or None, optional, default: None
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    figsize : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize : int, optional, default: 16
        The font size of the y-label and x-label text
    title_fontsize : int, optional, default: 18
        The font size of the title text
    tick_fontsize : int, optional, default: 12
        The font size of the axis ticklabels (also used for the legend text)
    color : str, optional, default: '#132a70'
        The color for the coverage line
    num_col : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    num_row : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of ``estimates`` and ``targets``.

    """

    # Collect the arrays to plot together with the figure/axes layout metadata
    layout = prepare_plot_data(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
        num_col=num_col,
        num_row=num_row,
        figsize=figsize,
    )
    draws = layout.pop("estimates")
    truths = layout.pop("targets")

    fig = layout["fig"]
    axes = layout["axes"]
    num_variables = layout["num_variables"]

    # Evenly spaced interval widths 0, 1/(S+1), ..., 1 where S is the number of posterior draws
    num_draws = draws.shape[1]
    grid = np.arange(num_draws + 2) / (num_draws + 1)

    stats = compute_empirical_coverage(
        estimates=draws,
        targets=truths,
        widths=grid,
        prob=0.95,
        interval_type="central",
    )

    # One subplot per variable; surplus axes in the grid are left untouched
    for idx, ax in zip(range(num_variables), axes.flat):
        nominal = stats["width_represented"][:, idx]
        observed = stats["coverage_estimates"][:, idx]
        lower = stats["coverage_lower"][:, idx]
        upper = stats["coverage_upper"][:, idx]

        # Uncertainty ribbon around the coverage estimate
        ax.fill_between(nominal, lower, upper, color="grey", alpha=0.33, label="95% Credible Interval")

        # Reference diagonal (y = x): coverage equals interval width
        ax.plot([0, 1], [0, 1], color="skyblue", linewidth=2.0, label="Ideal Coverage")

        # Observed coverage curve
        ax.plot(nominal, observed, color=color, alpha=1.0, label="Empirical Coverage")

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        # A single legend on the first subplot is enough
        if idx == 0:
            ax.legend(fontsize=tick_fontsize, loc="upper left")

    prettify_subplots(axes, num_subplots=num_variables, tick_fontsize=tick_fontsize)

    # Titles and axis labels with the requested font sizes
    add_titles_and_labels(
        axes=axes,
        num_row=layout["num_row"],
        num_col=layout["num_col"],
        title=layout["variable_names"],
        xlabel="Central interval width",
        ylabel="Observed coverage",
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
    )

    fig.tight_layout()
    return fig
147+
148+
149+
def coverage_diff(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    figsize: Sequence[int] = None,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    tick_fontsize: int = 12,
    color: str = "#132a70",
    num_col: int = None,
    num_row: int = None,
) -> plt.Figure:
    """
    Plot the difference between empirical and ideal coverage of central posterior
    credible intervals, one subplot per variable.

    The plotted quantity is ``coverage - width``, so ideal calibration corresponds to the
    horizontal line at zero. This makes deviations from ideal calibration more visible
    than in the standard coverage plot. The gray ribbon is the 95% credible interval of
    the coverage, shifted by the same interval width.

    Parameters
    ----------
    estimates : dict or np.ndarray of shape (num_datasets, num_post_draws, num_params)
        The posterior draws obtained from num_datasets
    targets : dict or np.ndarray of shape (num_datasets, num_params)
        The true parameter values used for generating num_datasets
    variable_keys : list or None, optional, default: None
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None
    figsize : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize : int, optional, default: 16
        The font size of the y-label and x-label text
    title_fontsize : int, optional, default: 18
        The font size of the title text
    tick_fontsize : int, optional, default: 12
        The font size of the axis ticklabels (also used for the legend text)
    color : str, optional, default: '#132a70'
        The color for the coverage difference line
    num_col : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    num_row : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ShapeError
        If there is a deviation from the expected shapes of ``estimates`` and ``targets``.

    """

    # Collect the arrays to plot together with the figure/axes layout metadata
    layout = prepare_plot_data(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
        num_col=num_col,
        num_row=num_row,
        figsize=figsize,
    )
    draws = layout.pop("estimates")
    truths = layout.pop("targets")

    fig = layout["fig"]
    axes = layout["axes"]
    num_variables = layout["num_variables"]

    # Evenly spaced interval widths 0, 1/(S+1), ..., 1 where S is the number of posterior draws
    num_draws = draws.shape[1]
    grid = np.arange(num_draws + 2) / (num_draws + 1)

    stats = compute_empirical_coverage(
        estimates=draws,
        targets=truths,
        widths=grid,
        prob=0.95,
        interval_type="central",
    )

    # One subplot per variable; surplus axes in the grid are left untouched
    for idx, ax in zip(range(num_variables), axes.flat):
        nominal = stats["width_represented"][:, idx]

        # Subtract the nominal width so that perfect calibration sits at zero
        excess = stats["coverage_estimates"][:, idx] - nominal
        excess_lower = stats["coverage_lower"][:, idx] - nominal
        excess_upper = stats["coverage_upper"][:, idx] - nominal

        # Uncertainty ribbon around the coverage difference
        ax.fill_between(nominal, excess_lower, excess_upper, color="grey", alpha=0.33, label="95% Credible Interval")

        # Reference line (y = 0): coverage equals interval width
        ax.axhline(y=0, color="skyblue", linewidth=2.0, label="Ideal Coverage")

        # Observed coverage difference curve
        ax.plot(nominal, excess, color=color, alpha=1.0, label="Coverage Difference")

        # Only the x-range is fixed; the y-range is left to matplotlib
        ax.set_xlim(0, 1)

        # A single legend on the first subplot is enough
        if idx == 0:
            ax.legend(fontsize=tick_fontsize, loc="upper right")

    prettify_subplots(axes, num_subplots=num_variables, tick_fontsize=tick_fontsize)

    # Titles and axis labels with the requested font sizes
    add_titles_and_labels(
        axes=axes,
        num_row=layout["num_row"],
        num_col=layout["num_col"],
        title=layout["variable_names"],
        xlabel="Central interval width",
        ylabel="Coverage difference",
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
    )

    fig.tight_layout()
    return fig

bayesflow/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
prettify_subplots,
7272
make_quadratic,
7373
add_metric,
74+
compute_empirical_coverage,
7475
)
7576
from .serialization import serialize_value_or_type, deserialize_value_or_type
7677

0 commit comments

Comments
 (0)