Skip to content

Commit d753b05

Browse files
Merge pull request #98 from elseml/Development
Omit unneccessary labels in plot_calibration_curves
2 parents f46374d + 6a5600b commit d753b05

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

bayesflow/diagnostics.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,8 @@ def plot_calibration_curves(
10601060
# Determine n_subplots dynamically
10611061
n_row = int(np.ceil(num_models / 6))
10621062
n_col = int(np.ceil(num_models / n_row))
1063+
1064+
# Compute calibration
10631065
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)
10641066

10651067
# Initialize figure
@@ -1094,8 +1096,6 @@ def plot_calibration_curves(
10941096
ax[j].spines["top"].set_visible(False)
10951097
ax[j].set_xlim([0 - epsilon, 1 + epsilon])
10961098
ax[j].set_ylim([0 - epsilon, 1 + epsilon])
1097-
ax[j].set_xlabel("Predicted probability", fontsize=label_fontsize)
1098-
ax[j].set_ylabel("True probability", fontsize=label_fontsize)
10991099
ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
11001100
ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
11011101
ax[j].grid(alpha=0.5)
@@ -1111,6 +1111,18 @@ def plot_calibration_curves(
11111111
size=legend_fontsize,
11121112
)
11131113

1114+
# Only add x-labels to the bottom row
1115+
bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :]
1116+
for _ax in bottom_row:
1117+
_ax.set_xlabel("Predicted probability", fontsize=label_fontsize)
1118+
1119+
# Only add y-labels to left-most row
1120+
if n_row == 1: # if there is only one row, the ax array is 1D
1121+
ax[0].set_ylabel("True probability", fontsize=label_fontsize)
1122+
else: # if there is more than one row, the ax array is 2D
1123+
for _ax in axarr[:, 0]:
1124+
_ax.set_ylabel("True probability", fontsize=label_fontsize)
1125+
11141126
fig.tight_layout()
11151127
return fig
11161128

0 commit comments

Comments
 (0)