@@ -1015,6 +1015,7 @@ def plot_calibration_curves(
10151015 legend_fontsize = 14 ,
10161016 title_fontsize = 18 ,
10171017 tick_fontsize = 12 ,
1018+ epsilon = 0.02 ,
10181019 fig_size = None ,
10191020 color = "#8f2727" ,
10201021):
@@ -1040,6 +1041,8 @@ def plot_calibration_curves(
10401041 The font size of the title text. Only relevant if `stacked=False`
10411042 tick_fontsize : int, optional, default: 12
10421043 The font size of the axis ticklabels
1044+ epsilon : float, optional, default: 0.02
1045+ A small amount to pad the [0, 1]-bounded axes from both side.
10431046 fig_size : tuple or None, optional, default: None
10441047 The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
10451048 color : str, optional, default: '#8f2727'
@@ -1073,26 +1076,31 @@ def plot_calibration_curves(
10731076 ax = axarr
10741077 for j in range (num_models ):
10751078 # Plot calibration curve
1076- ax [j ].plot (probs_pred [j ], probs_true [j ], color = color )
1077-
1078- # Plot AB line
1079- ax [j ].plot (ax [j ].get_xlim (), ax [j ].get_xlim (), "--" , color = "darkgrey" )
1079+ ax [j ].plot (probs_pred [j ], probs_true [j ], "o-" , color = color )
10801080
10811081 # Plot PMP distribution over bins
10821082 uniform_bins = np .linspace (0.0 , 1.0 , num_bins + 1 )
10831083 norm_weights = np .ones_like (pred_models ) / len (pred_models )
10841084 ax [j ].hist (pred_models [:, j ], bins = uniform_bins , weights = norm_weights [:, j ], color = "grey" , alpha = 0.3 )
10851085
1086+ # Plot AB line
1087+ ax [j ].plot ((0 , 1 ), (0 , 1 ), "--" , color = "black" , alpha = 0.9 )
1088+
10861089 # Tweak plot
1090+ ax [j ].tick_params (axis = "both" , which = "major" , labelsize = tick_fontsize )
1091+ ax [j ].tick_params (axis = "both" , which = "minor" , labelsize = tick_fontsize )
1092+ ax [j ].set_title (model_names [j ], fontsize = title_fontsize )
10871093 ax [j ].spines ["right" ].set_visible (False )
10881094 ax [j ].spines ["top" ].set_visible (False )
1089- ax [j ].set_xlim ([0 , 1 ])
1090- ax [j ].set_ylim ([0 , 1 ])
1095+ ax [j ].set_xlim ([0 - epsilon , 1 + epsilon ])
1096+ ax [j ].set_ylim ([0 - epsilon , 1 + epsilon ])
10911097 ax [j ].set_xlabel ("Predicted probability" , fontsize = label_fontsize )
10921098 ax [j ].set_ylabel ("True probability" , fontsize = label_fontsize )
1093- ax [j ].set_xticks ([0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
1094- ax [j ].set_yticks ([0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
1099+ ax [j ].set_xticks ([0.0 , 0. 2 , 0.4 , 0.6 , 0.8 , 1.0 ])
1100+ ax [j ].set_yticks ([0.0 , 0. 2 , 0.4 , 0.6 , 0.8 , 1.0 ])
10951101 ax [j ].grid (alpha = 0.5 )
1102+
1103+ # Add ECE label
10961104 ax [j ].text (
10971105 0.1 ,
10981106 0.9 ,
@@ -1102,11 +1110,7 @@ def plot_calibration_curves(
11021110 transform = ax [j ].transAxes ,
11031111 size = legend_fontsize ,
11041112 )
1105- ax [j ].tick_params (axis = "both" , which = "major" , labelsize = tick_fontsize )
1106- ax [j ].tick_params (axis = "both" , which = "minor" , labelsize = tick_fontsize )
11071113
1108- # Set title
1109- ax [j ].set_title (model_names [j ], fontsize = title_fontsize )
11101114 fig .tight_layout ()
11111115 return fig
11121116
0 commit comments