Skip to content

Commit ccad76e

Browse files
Merge pull request #106 from LuSchumacher/Development
add label_fontsize and value_fontsize
2 parents 7214f78 + 41ee9a7 commit ccad76e

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

bayesflow/diagnostics.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,9 @@ def plot_confusion_matrix(
11401140
pred_models,
11411141
model_names=None,
11421142
fig_size=(5, 5),
1143+
label_fontsize=16,
11431144
title_fontsize=18,
1145+
value_fontsize=10,
11441146
tick_fontsize=12,
11451147
xtick_rotation=None,
11461148
ytick_rotation=None,
@@ -1160,8 +1162,12 @@ def plot_confusion_matrix(
11601162
The model names for nice plot titles. Inferred if None.
11611163
fig_size : tuple or None, optional, default: (5, 5)
11621164
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
1165+
label_fontsize : int, optional, default: 16
1166+
The font size of the y-label and y-label texts
11631167
title_fontsize : int, optional, default: 18
11641168
The font size of the title text.
1169+
value_fontsize : int, optional, default: 10
1170+
The font size of the text annotations and the colorbar tick labels.
11651171
tick_fontsize : int, optional, default: 12
11661172
The font size of the axis label and model name texts.
11671173
xtick_rotation: int, optional, default: None
@@ -1200,9 +1206,10 @@ def plot_confusion_matrix(
12001206

12011207
# Initialize figure
12021208
fig, ax = plt.subplots(1, 1, figsize=fig_size)
1203-
12041209
im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
1205-
ax.figure.colorbar(im, ax=ax, shrink=0.7)
1210+
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75)
1211+
1212+
cbar.ax.tick_params(labelsize=value_fontsize)
12061213

12071214
ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
12081215
ax.set_xticklabels(model_names, fontsize=tick_fontsize)
@@ -1211,22 +1218,21 @@ def plot_confusion_matrix(
12111218
ax.set_yticklabels(model_names, fontsize=tick_fontsize)
12121219
if ytick_rotation:
12131220
plt.yticks(rotation=ytick_rotation)
1214-
ax.set_xlabel("Predicted model", fontsize=tick_fontsize)
1215-
ax.set_ylabel("True model", fontsize=tick_fontsize)
1221+
ax.set_xlabel("Predicted model", fontsize=label_fontsize)
1222+
ax.set_ylabel("True model", fontsize=label_fontsize)
12161223

12171224
# Loop over data dimensions and create text annotations
12181225
fmt = ".2f" if normalize else "d"
12191226
thresh = cm.max() / 2.0
12201227
for i in range(cm.shape[0]):
12211228
for j in range(cm.shape[1]):
12221229
ax.text(
1223-
j, i, format(cm[i, j], fmt), ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
1230+
j, i, format(cm[i, j], fmt), fontsize=value_fontsize, ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
12241231
)
12251232
if title:
12261233
ax.set_title("Confusion Matrix", fontsize=title_fontsize)
12271234
return fig
12281235

1229-
12301236
def plot_mmd_hypothesis_test(
12311237
mmd_null,
12321238
mmd_observed=None,

0 commit comments

Comments
 (0)