@@ -1140,7 +1140,9 @@ def plot_confusion_matrix(
11401140 pred_models ,
11411141 model_names = None ,
11421142 fig_size = (5 , 5 ),
1143+ label_fontsize = 16 ,
11431144 title_fontsize = 18 ,
1145+ value_fontsize = 10 ,
11441146 tick_fontsize = 12 ,
11451147 xtick_rotation = None ,
11461148 ytick_rotation = None ,
@@ -1160,8 +1162,12 @@ def plot_confusion_matrix(
11601162 The model names for nice plot titles. Inferred if None.
11611163 fig_size : tuple or None, optional, default: (5, 5)
11621164 The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
1165+ label_fontsize : int, optional, default: 16
1166+ The font size of the y-label and y-label texts
11631167 title_fontsize : int, optional, default: 18
11641168 The font size of the title text.
1169+ value_fontsize : int, optional, default: 10
1170+ The font size of the text annotations and the colorbar tick labels.
11651171 tick_fontsize : int, optional, default: 12
11661172 The font size of the axis label and model name texts.
11671173 xtick_rotation: int, optional, default: None
@@ -1200,9 +1206,10 @@ def plot_confusion_matrix(
12001206
12011207 # Initialize figure
12021208 fig , ax = plt .subplots (1 , 1 , figsize = fig_size )
1203-
12041209 im = ax .imshow (cm , interpolation = "nearest" , cmap = cmap )
1205- ax .figure .colorbar (im , ax = ax , shrink = 0.7 )
1210+ cbar = ax .figure .colorbar (im , ax = ax , shrink = 0.75 )
1211+
1212+ cbar .ax .tick_params (labelsize = value_fontsize )
12061213
12071214 ax .set (xticks = np .arange (cm .shape [1 ]), yticks = np .arange (cm .shape [0 ]))
12081215 ax .set_xticklabels (model_names , fontsize = tick_fontsize )
@@ -1211,22 +1218,21 @@ def plot_confusion_matrix(
12111218 ax .set_yticklabels (model_names , fontsize = tick_fontsize )
12121219 if ytick_rotation :
12131220 plt .yticks (rotation = ytick_rotation )
1214- ax .set_xlabel ("Predicted model" , fontsize = tick_fontsize )
1215- ax .set_ylabel ("True model" , fontsize = tick_fontsize )
1221+ ax .set_xlabel ("Predicted model" , fontsize = label_fontsize )
1222+ ax .set_ylabel ("True model" , fontsize = label_fontsize )
12161223
12171224 # Loop over data dimensions and create text annotations
12181225 fmt = ".2f" if normalize else "d"
12191226 thresh = cm .max () / 2.0
12201227 for i in range (cm .shape [0 ]):
12211228 for j in range (cm .shape [1 ]):
12221229 ax .text (
1223- j , i , format (cm [i , j ], fmt ), ha = "center" , va = "center" , color = "white" if cm [i , j ] > thresh else "black"
1230+ j , i , format (cm [i , j ], fmt ), fontsize = value_fontsize , ha = "center" , va = "center" , color = "white" if cm [i , j ] > thresh else "black"
12241231 )
12251232 if title :
12261233 ax .set_title ("Confusion Matrix" , fontsize = title_fontsize )
12271234 return fig
12281235
1229-
12301236def plot_mmd_hypothesis_test (
12311237 mmd_null ,
12321238 mmd_observed = None ,
0 commit comments