Skip to content

Commit b2a9ded

Browse files
Improve default diagnostics (#584)
* Add coverage to default diagnostics * Improve diagnostics * Adapt test * Account for new metric * Fix doc [skip ci]
1 parent 98b9fc5 commit b2a9ded

File tree

10 files changed

+405
-158
lines changed

10 files changed

+405
-158
lines changed

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ def calibration_ecdf(
1515
variable_keys: Sequence[str] = None,
1616
variable_names: Sequence[str] = None,
1717
test_quantities: dict[str, Callable] = None,
18-
difference: bool = False,
18+
difference: bool = True,
1919
stacked: bool = False,
2020
rank_type: str | np.ndarray = "fractional",
2121
figsize: Sequence[float] = None,
2222
label_fontsize: int = 16,
2323
legend_fontsize: int = 14,
24-
legend_location: str = "upper right",
24+
legend_location: str = "lower right",
2525
title_fontsize: int = 18,
2626
tick_fontsize: int = 12,
2727
rank_ecdf_color: str = "#132a70",
@@ -59,7 +59,7 @@ def calibration_ecdf(
5959
The posterior draws obtained from n_data_sets
6060
targets : np.ndarray of shape (n_data_sets, n_params)
6161
The prior draws obtained for generating n_data_sets
62-
difference : bool, optional, default: False
62+
difference : bool, optional, default: True
6363
If `True`, plots the ECDF difference.
6464
Enables a more dynamic visualization range.
6565
stacked : bool, optional, default: False
@@ -98,7 +98,9 @@ def calibration_ecdf(
9898
label_fontsize : int, optional, default: 16
9999
The font size of the x-label and y-label texts
100100
legend_fontsize : int, optional, default: 14
101-
The font size of the legend text
101+
The font size of the legend text.
102+
legend_location : str, optional, default: 'lower right'
103+
The location of the legend.
102104
title_fontsize : int, optional, default: 18
103105
The font size of the title text.
104106
Only relevant if `stacked=False`
@@ -211,11 +213,13 @@ def calibration_ecdf(
211213
else:
212214
titles = ["Stacked ECDFs"]
213215

214-
for ax, title in zip(plot_data["axes"].flat, titles):
216+
for i, (ax, title) in enumerate(zip(plot_data["axes"].flat, titles)):
215217
ax.fill_between(z, L, U, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
216-
ax.legend(fontsize=legend_fontsize, loc=legend_location)
217218
ax.set_title(title, fontsize=title_fontsize)
218219

220+
if i == 0:
221+
ax.legend(fontsize=legend_fontsize, loc=legend_location)
222+
219223
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
220224

221225
add_titles_and_labels(

bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ def calibration_ecdf_from_quantiles(
1414
quantiles_key: str = "quantiles",
1515
variable_keys: Sequence[str] = None,
1616
variable_names: Sequence[str] = None,
17-
difference: bool = False,
17+
difference: bool = True,
1818
stacked: bool = False,
1919
figsize: Sequence[float] = None,
2020
label_fontsize: int = 16,
2121
legend_fontsize: int = 14,
22-
legend_location: str = "upper right",
22+
legend_location: str = "lower right",
2323
title_fontsize: int = 18,
2424
tick_fontsize: int = 12,
2525
rank_ecdf_color: str = "#132a70",
@@ -69,7 +69,7 @@ def calibration_ecdf_from_quantiles(
6969
variable_names : list or None, optional, default: None
7070
The parameter names for nice plot titles.
7171
Inferred if None. Only relevant if `stacked=False`.
72-
difference : bool, optional, default: False
72+
difference : bool, optional, default: True
7373
If `True`, plots the ECDF difference.
7474
Enables a more dynamic visualization range.
7575
stacked : bool, optional, default: False
@@ -82,7 +82,9 @@ def calibration_ecdf_from_quantiles(
8282
label_fontsize : int, optional, default: 16
8383
The font size of the x-label and y-label texts
8484
legend_fontsize : int, optional, default: 14
85-
The font size of the legend text
85+
The font size of the legend text.
86+
legend_location : str, optional, default: 'lower right'
87+
The location of the legend.
8688
title_fontsize : int, optional, default: 18
8789
The font size of the title text.
8890
Only relevant if `stacked=False`

bayesflow/diagnostics/plots/coverage.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ def coverage(
1414
variable_names: Sequence[str] = None,
1515
figsize: Sequence[int] = None,
1616
label_fontsize: int = 16,
17+
legend_fontsize: int = 14,
1718
title_fontsize: int = 18,
1819
tick_fontsize: int = 12,
20+
legend_location: str = "lower right",
1921
color: str = "#132a70",
2022
num_col: int = None,
2123
num_row: int = None,
@@ -39,7 +41,7 @@ def coverage(
3941
The posterior draws obtained from num_datasets
4042
targets : np.ndarray of shape (num_datasets, num_params)
4143
The true parameter values used for generating num_datasets
42-
difference : bool, optional, default: False
44+
difference : bool, optional, default: True
4345
If True, plots the difference between empirical coverage and ideal coverage
4446
(coverage - width), making deviations from ideal calibration more visible.
4547
If False, plots the standard coverage plot.
@@ -52,10 +54,14 @@ def coverage(
5254
The figure size passed to the matplotlib constructor. Inferred if None.
5355
label_fontsize : int, optional, default: 16
5456
The font size of the y-label and x-label text
57+
legend_fontsize : int, optional, default: 14
58+
The font size of the legend text
5559
title_fontsize : int, optional, default: 18
5660
The font size of the title text
5761
tick_fontsize : int, optional, default: 12
5862
The font size of the axis ticklabels
63+
legend_location : str, optional, default: 'lower right'
64+
The location of the legend.
5965
color : str, optional, default: '#132a70'
6066
The color for the coverage line
6167
num_row : int, optional, default: None
@@ -128,17 +134,11 @@ def coverage(
128134
)
129135

130136
# Plot ideal coverage difference line (y = 0)
131-
ax.axhline(y=0, color="skyblue", linewidth=2.0, label="Ideal Coverage")
137+
ax.axhline(y=0, color="black", linestyle="dashed", label="Ideal Coverage")
132138

133139
# Plot empirical coverage difference
134140
ax.plot(width_rep, diff_est, color=color, alpha=1.0, label="Coverage Difference")
135141

136-
# Set axis limits
137-
ax.set_xlim(0, 1)
138-
139-
# Add legend to first subplot
140-
if i == 0:
141-
ax.legend(fontsize=tick_fontsize, loc="upper right")
142142
else:
143143
# Plot confidence ribbon
144144
ax.fill_between(
@@ -151,23 +151,19 @@ def coverage(
151151
)
152152

153153
# Plot ideal coverage line (y = x)
154-
ax.plot([0, 1], [0, 1], color="skyblue", linewidth=2.0, label="Ideal Coverage")
154+
ax.plot([0, 1], [0, 1], color="black", linestyle="dashed", label="Ideal Coverage")
155155

156156
# Plot empirical coverage
157157
ax.plot(width_rep, coverage_est, color=color, alpha=1.0, label="Empirical Coverage")
158158

159-
# Set axis limits
160-
ax.set_xlim(0, 1)
161-
ax.set_ylim(0, 1)
162-
163-
# Add legend to first subplot
164-
if i == 0:
165-
ax.legend(fontsize=tick_fontsize, loc="upper left")
159+
# Add legend to first subplot
160+
if i == 0:
161+
ax.legend(fontsize=legend_fontsize, loc=legend_location)
166162

167163
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
168164

169165
# Add labels, titles, and set font sizes
170-
ylabel = "Observed coverage difference" if difference else "Observed coverage"
166+
ylabel = "Empirical coverage difference" if difference else "Empirical coverage"
171167
add_titles_and_labels(
172168
axes=plot_data["axes"],
173169
num_row=plot_data["num_row"],

bayesflow/diagnostics/plots/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def loss(
168168
num_col=1,
169169
title=["Loss Trajectory"],
170170
xlabel="Training epoch #",
171-
ylabel="Value",
171+
ylabel="Loss",
172172
title_fontsize=title_fontsize,
173173
label_fontsize=label_fontsize,
174174
)

bayesflow/diagnostics/plots/recovery.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ def recovery(
8787
The number of rows for the subplots. Dynamically determined if None.
8888
num_col : int, optional, default: None
8989
The number of columns for the subplots. Dynamically determined if None.
90-
xlabel :
91-
ylabel :
90+
xlabel : str, optional, default: "Ground truth"
91+
The label shown on the x-axis.
92+
ylabel : str, optional, default: "Estimate"
93+
The label shown on the y-axis.
9294
markersize : float, optional, default: None
9395
The marker size in points.
9496

bayesflow/workflows/basic_workflow.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,12 @@ def plot_default_diagnostics(
349349
- Loss history (if training history is available).
350350
- Parameter recovery plots.
351351
- Calibration ECDF plots.
352+
- Coverage plots.
352353
- Z-score contraction plots.
353354
355+
Caution: For models with many parameters, plotting all marginal diagnostics becomes unwieldy. Consider
356+
providing `variable_keys` for visualizing the diagnostics for subsets of the parameter space.
357+
354358
Parameters
355359
----------
356360
test_data : Mapping[str, np.ndarray] or int
@@ -400,6 +404,7 @@ def plot_default_diagnostics(
400404
plot_fns = {
401405
"recovery": bf_plots.recovery,
402406
"calibration_ecdf": bf_plots.calibration_ecdf,
407+
"coverage": bf_plots.coverage,
403408
"z_score_contraction": bf_plots.z_score_contraction,
404409
}
405410

@@ -499,9 +504,10 @@ def compute_default_diagnostics(
499504
"""
500505
Computes default diagnostic metrics to evaluate the quality of inference. The function computes several
501506
diagnostic metrics, including:
502-
- Root Mean Squared Error (RMSE)
503-
- Posterior contraction
504-
- Calibration error
507+
- (Normalized) Root Mean Squared Error ((N)RMSE) - summarizes the recovery plots
508+
- Log-gamma statistic - summarizes the ECDF calibration plots
509+
- Expected Calibration Error (ECE) - summarizes the coverage plots
510+
- Posterior contraction - partially summarizes the contraction plots
505511
506512
Parameters
507513
----------
@@ -553,12 +559,12 @@ def compute_default_diagnostics(
553559
**kwargs.get("root_mean_squared_error_kwargs", {}),
554560
)
555561

556-
contraction = bf_metrics.posterior_contraction(
562+
log_gamma = bf_metrics.calibration_log_gamma(
557563
estimates=samples,
558564
targets=test_data,
559565
variable_keys=variable_keys,
560566
variable_names=variable_names,
561-
**kwargs.get("posterior_contraction_kwargs", {}),
567+
**kwargs.get("log_gamma_kwargs", {}),
562568
)
563569

564570
calibration_errors = bf_metrics.calibration_error(
@@ -569,17 +575,26 @@ def compute_default_diagnostics(
569575
**kwargs.get("calibration_error_kwargs", {}),
570576
)
571577

578+
contraction = bf_metrics.posterior_contraction(
579+
estimates=samples,
580+
targets=test_data,
581+
variable_keys=variable_keys,
582+
variable_names=variable_names,
583+
**kwargs.get("posterior_contraction_kwargs", {}),
584+
)
585+
572586
if as_data_frame:
573587
metrics = pd.DataFrame(
574588
{
575589
root_mean_squared_error["metric_name"]: root_mean_squared_error["values"],
576-
contraction["metric_name"]: contraction["values"],
590+
log_gamma["metric_name"]: log_gamma["values"],
577591
calibration_errors["metric_name"]: calibration_errors["values"],
592+
contraction["metric_name"]: contraction["values"],
578593
},
579594
index=variable_keys or root_mean_squared_error["variable_names"],
580595
).T
581596
else:
582-
metrics = (root_mean_squared_error, contraction, calibration_errors)
597+
metrics = (root_mean_squared_error, log_gamma, calibration_errors, contraction)
583598

584599
return metrics
585600

examples/Multimodal_Data.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@
524524
},
525525
{
526526
"cell_type": "code",
527-
"execution_count": 16,
527+
"execution_count": null,
528528
"id": "2415fd0b-f5d6-4fc9-83d7-8952e6270186",
529529
"metadata": {},
530530
"outputs": [

0 commit comments

Comments
 (0)