Skip to content

Commit b2b8305

Browse files
committed
add fontsize arguments
1 parent 7214f78 commit b2b8305

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

bayesflow/diagnostics.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,10 +1140,12 @@ def plot_confusion_matrix(
11401140
pred_models,
11411141
model_names=None,
11421142
fig_size=(5, 5),
1143+
label_fontsize=16,
11431144
title_fontsize=18,
1145+
value_fontsize=10,
11441146
tick_fontsize=12,
1145-
xtick_rotation=None,
1146-
ytick_rotation=None,
1147+
xtick_rotation=0,
1148+
ytick_rotation=90,
11471149
normalize=True,
11481150
cmap=None,
11491151
title=True,
@@ -1160,13 +1162,17 @@ def plot_confusion_matrix(
11601162
The model names for nice plot titles. Inferred if None.
11611163
fig_size : tuple or None, optional, default: (5, 5)
11621164
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
1165+
label_fontsize : int, optional, default: 16
1166+
The font size of the y-label and y-label texts
11631167
title_fontsize : int, optional, default: 18
11641168
The font size of the title text.
1169+
value_fontsize : int, optional, default: 10
1170+
The font size of the text annotations and the colorbar tick labels.
11651171
tick_fontsize : int, optional, default: 12
11661172
The font size of the axis label and model name texts.
1167-
xtick_rotation: int, optional, default: None
1173+
xtick_rotation: int, optional, default: 0
11681174
Rotation of x-axis tick labels (helps with long model names).
1169-
ytick_rotation: int, optional, default: None
1175+
ytick_rotation: int, optional, default: 90
11701176
Rotation of y-axis tick labels (helps with long model names).
11711177
normalize : bool, optional, default: True
11721178
A flag for normalization of the confusion matrix.
@@ -1201,8 +1207,10 @@ def plot_confusion_matrix(
12011207
# Initialize figure
12021208
fig, ax = plt.subplots(1, 1, figsize=fig_size)
12031209

1204-
im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
1205-
ax.figure.colorbar(im, ax=ax, shrink=0.7)
1210+
im = ax.imshow(cm, interpolation="nearest", cmap=cmap, vmin=0.0, vmax=1.0)
1211+
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75)
1212+
1213+
cbar.ax.tick_params(labelsize=value_fontsize)
12061214

12071215
ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
12081216
ax.set_xticklabels(model_names, fontsize=tick_fontsize)
@@ -1211,16 +1219,16 @@ def plot_confusion_matrix(
12111219
ax.set_yticklabels(model_names, fontsize=tick_fontsize)
12121220
if ytick_rotation:
12131221
plt.yticks(rotation=ytick_rotation)
1214-
ax.set_xlabel("Predicted model", fontsize=tick_fontsize)
1215-
ax.set_ylabel("True model", fontsize=tick_fontsize)
1222+
ax.set_xlabel("Predicted model", fontsize=label_fontsize)
1223+
ax.set_ylabel("True model", fontsize=label_fontsize)
12161224

12171225
# Loop over data dimensions and create text annotations
12181226
fmt = ".2f" if normalize else "d"
12191227
thresh = cm.max() / 2.0
12201228
for i in range(cm.shape[0]):
12211229
for j in range(cm.shape[1]):
12221230
ax.text(
1223-
j, i, format(cm[i, j], fmt), ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
1231+
j, i, format(cm[i, j], fmt), fontsize=value_fontsize, ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
12241232
)
12251233
if title:
12261234
ax.set_title("Confusion Matrix", fontsize=title_fontsize)

0 commit comments

Comments
 (0)