@@ -1140,10 +1140,12 @@ def plot_confusion_matrix(
11401140 pred_models ,
11411141 model_names = None ,
11421142 fig_size = (5 , 5 ),
1143+ label_fontsize = 16 ,
11431144 title_fontsize = 18 ,
1145+ value_fontsize = 10 ,
11441146 tick_fontsize = 12 ,
1145- xtick_rotation = None ,
1146- ytick_rotation = None ,
1147+ xtick_rotation = 0 ,
1148+ ytick_rotation = 90 ,
11471149 normalize = True ,
11481150 cmap = None ,
11491151 title = True ,
@@ -1160,13 +1162,17 @@ def plot_confusion_matrix(
11601162 The model names for nice plot titles. Inferred if None.
11611163 fig_size : tuple or None, optional, default: (5, 5)
11621164 The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
1165+ label_fontsize : int, optional, default: 16
1166+ The font size of the y-label and y-label texts
11631167 title_fontsize : int, optional, default: 18
11641168 The font size of the title text.
1169+ value_fontsize : int, optional, default: 10
1170+ The font size of the text annotations and the colorbar tick labels.
11651171 tick_fontsize : int, optional, default: 12
11661172 The font size of the axis label and model name texts.
1167- xtick_rotation: int, optional, default: None
1173+ xtick_rotation: int, optional, default: 0
11681174 Rotation of x-axis tick labels (helps with long model names).
1169- ytick_rotation: int, optional, default: None
1175+ ytick_rotation: int, optional, default: 90
11701176 Rotation of y-axis tick labels (helps with long model names).
11711177 normalize : bool, optional, default: True
11721178 A flag for normalization of the confusion matrix.
@@ -1201,8 +1207,10 @@ def plot_confusion_matrix(
12011207 # Initialize figure
12021208 fig , ax = plt .subplots (1 , 1 , figsize = fig_size )
12031209
1204- im = ax .imshow (cm , interpolation = "nearest" , cmap = cmap )
1205- ax .figure .colorbar (im , ax = ax , shrink = 0.7 )
1210+ im = ax .imshow (cm , interpolation = "nearest" , cmap = cmap , vmin = 0.0 , vmax = 1.0 )
1211+ cbar = ax .figure .colorbar (im , ax = ax , shrink = 0.75 )
1212+
1213+ cbar .ax .tick_params (labelsize = value_fontsize )
12061214
12071215 ax .set (xticks = np .arange (cm .shape [1 ]), yticks = np .arange (cm .shape [0 ]))
12081216 ax .set_xticklabels (model_names , fontsize = tick_fontsize )
@@ -1211,16 +1219,16 @@ def plot_confusion_matrix(
12111219 ax .set_yticklabels (model_names , fontsize = tick_fontsize )
12121220 if ytick_rotation :
12131221 plt .yticks (rotation = ytick_rotation )
1214- ax .set_xlabel ("Predicted model" , fontsize = tick_fontsize )
1215- ax .set_ylabel ("True model" , fontsize = tick_fontsize )
1222+ ax .set_xlabel ("Predicted model" , fontsize = label_fontsize )
1223+ ax .set_ylabel ("True model" , fontsize = label_fontsize )
12161224
12171225 # Loop over data dimensions and create text annotations
12181226 fmt = ".2f" if normalize else "d"
12191227 thresh = cm .max () / 2.0
12201228 for i in range (cm .shape [0 ]):
12211229 for j in range (cm .shape [1 ]):
12221230 ax .text (
1223- j , i , format (cm [i , j ], fmt ), ha = "center" , va = "center" , color = "white" if cm [i , j ] > thresh else "black"
1231+ j , i , format (cm [i , j ], fmt ), fontsize = value_fontsize , ha = "center" , va = "center" , color = "white" if cm [i , j ] > thresh else "black"
12241232 )
12251233 if title :
12261234 ax .set_title ("Confusion Matrix" , fontsize = title_fontsize )
0 commit comments