@@ -1060,6 +1060,8 @@ def plot_calibration_curves(
10601060 # Determine n_subplots dynamically
10611061 n_row = int (np .ceil (num_models / 6 ))
10621062 n_col = int (np .ceil (num_models / n_row ))
1063+
1064+ # Compute calibration
10631065 cal_errs , probs_true , probs_pred = expected_calibration_error (true_models , pred_models , num_bins )
10641066
10651067 # Initialize figure
@@ -1094,8 +1096,6 @@ def plot_calibration_curves(
10941096 ax [j ].spines ["top" ].set_visible (False )
10951097 ax [j ].set_xlim ([0 - epsilon , 1 + epsilon ])
10961098 ax [j ].set_ylim ([0 - epsilon , 1 + epsilon ])
1097- ax [j ].set_xlabel ("Predicted probability" , fontsize = label_fontsize )
1098- ax [j ].set_ylabel ("True probability" , fontsize = label_fontsize )
10991099 ax [j ].set_xticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
11001100 ax [j ].set_yticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
11011101 ax [j ].grid (alpha = 0.5 )
@@ -1111,6 +1111,18 @@ def plot_calibration_curves(
11111111 size = legend_fontsize ,
11121112 )
11131113
1114+ # Only add x-labels to the bottom row
1115+ bottom_row = axarr if n_row == 1 else axarr [0 ] if n_col == 1 else axarr [n_row - 1 , :]
1116+ for _ax in bottom_row :
1117+ _ax .set_xlabel ("Predicted probability" , fontsize = label_fontsize )
1118+
1119+ # Only add y-labels to left-most row
1120+ if n_row == 1 : # if there is only one row, the ax array is 1D
1121+ ax [0 ].set_ylabel ("True probability" , fontsize = label_fontsize )
1122+ else : # if there is more than one row, the ax array is 2D
1123+ for _ax in axarr [:, 0 ]:
1124+ _ax .set_ylabel ("True probability" , fontsize = label_fontsize )
1125+
11141126 fig .tight_layout ()
11151127 return fig
11161128
0 commit comments