Skip to content

Commit 3c29d9e

Browse files
committed
Add coverage to default diagnostics
1 parent 98b9fc5 commit 3c29d9e

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

bayesflow/diagnostics/plots/coverage.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ def coverage(
1414
variable_names: Sequence[str] = None,
1515
figsize: Sequence[int] = None,
1616
label_fontsize: int = 16,
17+
legend_fontsize: int = 14,
1718
title_fontsize: int = 18,
1819
tick_fontsize: int = 12,
20+
legend_location: str = "upper right",
1921
color: str = "#132a70",
2022
num_col: int = None,
2123
num_row: int = None,
@@ -52,6 +54,8 @@ def coverage(
5254
The figure size passed to the matplotlib constructor. Inferred if None.
5355
label_fontsize : int, optional, default: 16
5456
The font size of the y-label and x-label text
57+
legend_fontsize : int, optional, default: 14
58+
The font size of the legend text
5559
title_fontsize : int, optional, default: 18
5660
The font size of the title text
5761
tick_fontsize : int, optional, default: 12
@@ -133,12 +137,6 @@ def coverage(
133137
# Plot empirical coverage difference
134138
ax.plot(width_rep, diff_est, color=color, alpha=1.0, label="Coverage Difference")
135139

136-
# Set axis limits
137-
ax.set_xlim(0, 1)
138-
139-
# Add legend to first subplot
140-
if i == 0:
141-
ax.legend(fontsize=tick_fontsize, loc="upper right")
142140
else:
143141
# Plot confidence ribbon
144142
ax.fill_between(
@@ -156,13 +154,13 @@ def coverage(
156154
# Plot empirical coverage
157155
ax.plot(width_rep, coverage_est, color=color, alpha=1.0, label="Empirical Coverage")
158156

159-
# Set axis limits
160-
ax.set_xlim(0, 1)
161-
ax.set_ylim(0, 1)
157+
# Set axis limits
158+
ax.set_xlim(0, 1)
159+
ax.set_ylim(0, 1)
162160

163-
# Add legend to first subplot
164-
if i == 0:
165-
ax.legend(fontsize=tick_fontsize, loc="upper left")
161+
# Add legend to first subplot
162+
if i == 0:
163+
ax.legend(fontsize=legend_fontsize, loc=legend_location)
166164

167165
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
168166

bayesflow/workflows/basic_workflow.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,12 @@ def plot_default_diagnostics(
349349
- Loss history (if training history is available).
350350
- Parameter recovery plots.
351351
- Calibration ECDF plots.
352+
- Coverage plots.
352353
- Z-score contraction plots.
353354
355+
Caution: For models with many parameters, plotting all marginal diagnostics becomes unwieldy. Consider
356+
providing `variables_keyes` for visualizing the diagnostics for subsets of the parameter space.
357+
354358
Parameters
355359
----------
356360
test_data : Mapping[str, np.ndarray] or int
@@ -400,6 +404,7 @@ def plot_default_diagnostics(
400404
plot_fns = {
401405
"recovery": bf_plots.recovery,
402406
"calibration_ecdf": bf_plots.calibration_ecdf,
407+
"coverage": bf_plots.coverage,
403408
"z_score_contraction": bf_plots.z_score_contraction,
404409
}
405410

@@ -499,9 +504,10 @@ def compute_default_diagnostics(
499504
"""
500505
Computes default diagnostic metrics to evaluate the quality of inference. The function computes several
501506
diagnostic metrics, including:
502-
- Root Mean Squared Error (RMSE)
503-
- Posterior contraction
504-
- Calibration error
507+
- (Normalized) Root Mean Squared Error ((N)RMSE): summarizes the recovery plots
508+
- Log-gamma statistic - summarizes the ECDF calibration plots
509+
- Expected Calibration Error (ECE) - summarizes the coverage plots
510+
- Posterior contraction - partially summarizes the contraction plots
505511
506512
Parameters
507513
----------
@@ -553,12 +559,12 @@ def compute_default_diagnostics(
553559
**kwargs.get("root_mean_squared_error_kwargs", {}),
554560
)
555561

556-
contraction = bf_metrics.posterior_contraction(
562+
log_gamma = bf_metrics.calibration_log_gamma(
557563
estimates=samples,
558564
targets=test_data,
559565
variable_keys=variable_keys,
560566
variable_names=variable_names,
561-
**kwargs.get("posterior_contraction_kwargs", {}),
567+
**kwargs.get("log_gamma_kwargs", {}),
562568
)
563569

564570
calibration_errors = bf_metrics.calibration_error(
@@ -569,17 +575,26 @@ def compute_default_diagnostics(
569575
**kwargs.get("calibration_error_kwargs", {}),
570576
)
571577

578+
contraction = bf_metrics.posterior_contraction(
579+
estimates=samples,
580+
targets=test_data,
581+
variable_keys=variable_keys,
582+
variable_names=variable_names,
583+
**kwargs.get("posterior_contraction_kwargs", {}),
584+
)
585+
572586
if as_data_frame:
573587
metrics = pd.DataFrame(
574588
{
575589
root_mean_squared_error["metric_name"]: root_mean_squared_error["values"],
576-
contraction["metric_name"]: contraction["values"],
590+
log_gamma["metric_name"]: log_gamma["values"],
577591
calibration_errors["metric_name"]: calibration_errors["values"],
592+
contraction["metric_name"]: contraction["values"],
578593
},
579594
index=variable_keys or root_mean_squared_error["variable_names"],
580595
).T
581596
else:
582-
metrics = (root_mean_squared_error, contraction, calibration_errors)
597+
metrics = (root_mean_squared_error, log_gamma, calibration_errors, contraction)
583598

584599
return metrics
585600

0 commit comments

Comments
 (0)