Skip to content

Commit bbf3da7

Browse files
committed
Merge the plotting functions via a `difference` argument and add a Beta-Binomial reference
1 parent 7430d8c commit bbf3da7

File tree

1 file changed

+61
-165
lines changed

1 file changed

+61
-165
lines changed

bayesflow/diagnostics/plots/coverage.py

Lines changed: 61 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
def coverage(
1010
estimates: Mapping[str, np.ndarray] | np.ndarray,
1111
targets: Mapping[str, np.ndarray] | np.ndarray,
12+
difference: bool = False,
1213
variable_keys: Sequence[str] = None,
1314
variable_names: Sequence[str] = None,
1415
figsize: Sequence[int] = None,
@@ -29,13 +30,19 @@ def coverage(
2930
3031
The coverage is accompanied by credible intervals for the coverage (gray ribbon).
3132
These are computed via the (conjugate) Beta-Binomial model for binomial proportions with a uniform prior.
33+
For more details on the Beta-Binomial model, see Chapter 2 of Bayesian Data Analysis (2013, 3rd ed.) by
34+
Gelman A., Carlin J., Stern H., Dunson D., Vehtari A., & Rubin D.
3235
3336
Parameters
3437
----------
3538
estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
3639
The posterior draws obtained from num_datasets
3740
targets : np.ndarray of shape (num_datasets, num_params)
3841
The true parameter values used for generating num_datasets
42+
difference : bool, optional, default: False
43+
If True, plots the difference between empirical coverage and ideal coverage
44+
(coverage - width), making deviations from ideal calibration more visible.
45+
If False, plots the standard coverage plot.
3946
variable_keys : list or None, optional, default: None
4047
Select keys from the dictionaries provided in estimates and targets.
4148
By default, select all keys.
@@ -104,181 +111,70 @@ def coverage(
104111
coverage_low = coverage_data["coverage_lower"][:, i]
105112
coverage_high = coverage_data["coverage_upper"][:, i]
106113

107-
# Plot confidence ribbon
108-
ax.fill_between(
109-
width_rep,
110-
coverage_low,
111-
coverage_high,
112-
color="grey",
113-
alpha=0.33,
114-
label="95% Credible Interval",
115-
)
116-
117-
# Plot ideal coverage line (y = x)
118-
ax.plot([0, 1], [0, 1], color="skyblue", linewidth=2.0, label="Ideal Coverage")
119-
120-
# Plot empirical coverage
121-
ax.plot(width_rep, coverage_est, color=color, alpha=1.0, label="Empirical Coverage")
122-
123-
# Set axis limits
124-
ax.set_xlim(0, 1)
125-
ax.set_ylim(0, 1)
126-
127-
# Add legend to first subplot
128-
if i == 0:
129-
ax.legend(fontsize=tick_fontsize, loc="upper left")
130-
131-
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
132-
133-
# Add labels, titles, and set font sizes
134-
add_titles_and_labels(
135-
axes=plot_data["axes"],
136-
num_row=plot_data["num_row"],
137-
num_col=plot_data["num_col"],
138-
title=plot_data["variable_names"],
139-
xlabel="Central interval width",
140-
ylabel="Observed coverage",
141-
title_fontsize=title_fontsize,
142-
label_fontsize=label_fontsize,
143-
)
144-
145-
plot_data["fig"].tight_layout()
146-
return plot_data["fig"]
147-
148-
149-
def coverage_diff(
150-
estimates: Mapping[str, np.ndarray] | np.ndarray,
151-
targets: Mapping[str, np.ndarray] | np.ndarray,
152-
variable_keys: Sequence[str] = None,
153-
variable_names: Sequence[str] = None,
154-
figsize: Sequence[int] = None,
155-
label_fontsize: int = 16,
156-
title_fontsize: int = 18,
157-
tick_fontsize: int = 12,
158-
color: str = "#132a70",
159-
num_col: int = None,
160-
num_row: int = None,
161-
) -> plt.Figure:
162-
"""
163-
Creates coverage difference plots showing the difference between empirical coverage
164-
and ideal coverage of posterior credible intervals.
165-
166-
This plot shows coverage - width, making deviations from ideal calibration
167-
more visible than the standard coverage plot.
168-
For more details, see the documentation of the standard coverage plot.
169-
170-
Parameters
171-
----------
172-
estimates : np.ndarray of shape (num_datasets, num_post_draws, num_params)
173-
The posterior draws obtained from num_datasets
174-
targets : np.ndarray of shape (num_datasets, num_params)
175-
The true parameter values used for generating num_datasets
176-
variable_keys : list or None, optional, default: None
177-
Select keys from the dictionaries provided in estimates and targets.
178-
By default, select all keys.
179-
variable_names : list or None, optional, default: None
180-
The parameter names for nice plot titles. Inferred if None
181-
figsize : tuple or None, optional, default: None
182-
The figure size passed to the matplotlib constructor. Inferred if None.
183-
label_fontsize : int, optional, default: 16
184-
The font size of the y-label and x-label text
185-
title_fontsize : int, optional, default: 18
186-
The font size of the title text
187-
tick_fontsize : int, optional, default: 12
188-
The font size of the axis ticklabels
189-
color : str, optional, default: '#132a70'
190-
The color for the coverage difference line
191-
num_row : int, optional, default: None
192-
The number of rows for the subplots. Dynamically determined if None.
193-
num_col : int, optional, default: None
194-
The number of columns for the subplots. Dynamically determined if None.
195-
196-
Returns
197-
-------
198-
f : plt.Figure - the figure instance for optional saving
199-
200-
Raises
201-
------
202-
ShapeError
203-
If there is a deviation from the expected shapes of ``estimates`` and ``targets``.
204-
205-
"""
206-
207-
# Gather plot data and metadata into a dictionary
208-
plot_data = prepare_plot_data(
209-
estimates=estimates,
210-
targets=targets,
211-
variable_keys=variable_keys,
212-
variable_names=variable_names,
213-
num_col=num_col,
214-
num_row=num_row,
215-
figsize=figsize,
216-
)
217-
218-
estimates = plot_data.pop("estimates")
219-
targets = plot_data.pop("targets")
220-
221-
# Determine widths to compute coverage for
222-
num_draws = estimates.shape[1]
223-
widths = np.arange(0, num_draws + 2) / (num_draws + 1)
224-
225-
# Compute empirical coverage with default parameters
226-
coverage_data = compute_empirical_coverage(
227-
estimates=estimates,
228-
targets=targets,
229-
widths=widths,
230-
prob=0.95,
231-
interval_type="central",
232-
)
233-
234-
# Plot coverage difference for each parameter
235-
for i, ax in enumerate(plot_data["axes"].flat):
236-
if i >= plot_data["num_variables"]:
237-
break
238-
239-
width_rep = coverage_data["width_represented"][:, i]
240-
coverage_est = coverage_data["coverage_estimates"][:, i]
241-
coverage_low = coverage_data["coverage_lower"][:, i]
242-
coverage_high = coverage_data["coverage_upper"][:, i]
243-
244-
# Compute differences
245-
diff_est = coverage_est - width_rep
246-
diff_low = coverage_low - width_rep
247-
diff_high = coverage_high - width_rep
248-
249-
# Plot confidence ribbon
250-
ax.fill_between(
251-
width_rep,
252-
diff_low,
253-
diff_high,
254-
color="grey",
255-
alpha=0.33,
256-
label="95% Credible Interval",
257-
)
258-
259-
# Plot ideal coverage difference line (y = 0)
260-
ax.axhline(y=0, color="skyblue", linewidth=2.0, label="Ideal Coverage")
261-
262-
# Plot empirical coverage difference
263-
ax.plot(width_rep, diff_est, color=color, alpha=1.0, label="Coverage Difference")
264-
265-
# Set axis limits
266-
ax.set_xlim(0, 1)
267-
268-
# Add legend to first subplot
269-
if i == 0:
270-
ax.legend(fontsize=tick_fontsize, loc="upper right")
114+
if difference:
115+
# Compute differences for coverage difference plot
116+
diff_est = coverage_est - width_rep
117+
diff_low = coverage_low - width_rep
118+
diff_high = coverage_high - width_rep
119+
120+
# Plot confidence ribbon
121+
ax.fill_between(
122+
width_rep,
123+
diff_low,
124+
diff_high,
125+
color="grey",
126+
alpha=0.33,
127+
label="95% Credible Interval",
128+
)
129+
130+
# Plot ideal coverage difference line (y = 0)
131+
ax.axhline(y=0, color="skyblue", linewidth=2.0, label="Ideal Coverage")
132+
133+
# Plot empirical coverage difference
134+
ax.plot(width_rep, diff_est, color=color, alpha=1.0, label="Coverage Difference")
135+
136+
# Set axis limits
137+
ax.set_xlim(0, 1)
138+
139+
# Add legend to first subplot
140+
if i == 0:
141+
ax.legend(fontsize=tick_fontsize, loc="upper right")
142+
else:
143+
# Plot confidence ribbon
144+
ax.fill_between(
145+
width_rep,
146+
coverage_low,
147+
coverage_high,
148+
color="grey",
149+
alpha=0.33,
150+
label="95% Credible Interval",
151+
)
152+
153+
# Plot ideal coverage line (y = x)
154+
ax.plot([0, 1], [0, 1], color="skyblue", linewidth=2.0, label="Ideal Coverage")
155+
156+
# Plot empirical coverage
157+
ax.plot(width_rep, coverage_est, color=color, alpha=1.0, label="Empirical Coverage")
158+
159+
# Set axis limits
160+
ax.set_xlim(0, 1)
161+
ax.set_ylim(0, 1)
162+
163+
# Add legend to first subplot
164+
if i == 0:
165+
ax.legend(fontsize=tick_fontsize, loc="upper left")
271166

272167
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
273168

274169
# Add labels, titles, and set font sizes
170+
ylabel = "Observed coverage difference" if difference else "Observed coverage"
275171
add_titles_and_labels(
276172
axes=plot_data["axes"],
277173
num_row=plot_data["num_row"],
278174
num_col=plot_data["num_col"],
279175
title=plot_data["variable_names"],
280176
xlabel="Central interval width",
281-
ylabel="Coverage difference",
177+
ylabel=ylabel,
282178
title_fontsize=title_fontsize,
283179
label_fontsize=label_fontsize,
284180
)

0 commit comments

Comments
 (0)