Skip to content

Commit eb911c7

Browse files
committed
Aesthetisize plot_calibration_curves
1 parent 4061488 commit eb911c7

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

bayesflow/diagnostics.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,7 @@ def plot_calibration_curves(
10151015
legend_fontsize=14,
10161016
title_fontsize=18,
10171017
tick_fontsize=12,
1018+
epsilon=0.02,
10181019
fig_size=None,
10191020
color="#8f2727",
10201021
):
@@ -1040,6 +1041,8 @@ def plot_calibration_curves(
10401041
The font size of the title text. Only relevant if `stacked=False`
10411042
tick_fontsize : int, optional, default: 12
10421043
The font size of the axis ticklabels
1044+
epsilon : float, optional, default: 0.02
1045+
A small amount to pad the [0, 1]-bounded axes from both side.
10431046
fig_size : tuple or None, optional, default: None
10441047
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
10451048
color : str, optional, default: '#8f2727'
@@ -1073,26 +1076,31 @@ def plot_calibration_curves(
10731076
ax = axarr
10741077
for j in range(num_models):
10751078
# Plot calibration curve
1076-
ax[j].plot(probs_pred[j], probs_true[j], color=color)
1077-
1078-
# Plot AB line
1079-
ax[j].plot(ax[j].get_xlim(), ax[j].get_xlim(), "--", color="darkgrey")
1079+
ax[j].plot(probs_pred[j], probs_true[j], "o-", color=color)
10801080

10811081
# Plot PMP distribution over bins
10821082
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
10831083
norm_weights = np.ones_like(pred_models) / len(pred_models)
10841084
ax[j].hist(pred_models[:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)
10851085

1086+
# Plot AB line
1087+
ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9)
1088+
10861089
# Tweak plot
1090+
ax[j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
1091+
ax[j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
1092+
ax[j].set_title(model_names[j], fontsize=title_fontsize)
10871093
ax[j].spines["right"].set_visible(False)
10881094
ax[j].spines["top"].set_visible(False)
1089-
ax[j].set_xlim([0, 1])
1090-
ax[j].set_ylim([0, 1])
1095+
ax[j].set_xlim([0 - epsilon, 1 + epsilon])
1096+
ax[j].set_ylim([0 - epsilon, 1 + epsilon])
10911097
ax[j].set_xlabel("Predicted probability", fontsize=label_fontsize)
10921098
ax[j].set_ylabel("True probability", fontsize=label_fontsize)
1093-
ax[j].set_xticks([0.2, 0.4, 0.6, 0.8, 1.0])
1094-
ax[j].set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
1099+
ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
1100+
ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
10951101
ax[j].grid(alpha=0.5)
1102+
1103+
# Add ECE label
10961104
ax[j].text(
10971105
0.1,
10981106
0.9,
@@ -1102,11 +1110,7 @@ def plot_calibration_curves(
11021110
transform=ax[j].transAxes,
11031111
size=legend_fontsize,
11041112
)
1105-
ax[j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
1106-
ax[j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
11071113

1108-
# Set title
1109-
ax[j].set_title(model_names[j], fontsize=title_fontsize)
11101114
fig.tight_layout()
11111115
return fig
11121116

0 commit comments

Comments
 (0)