@@ -68,41 +68,42 @@ def mc_calibration(
6868
6969 # Gather plot data and metadata into a dictionary
7070 plot_data = prepare_plot_data (
71- estimates = pred_models ,
72- ground_truths = true_models ,
71+ targets = pred_models ,
72+ references = true_models ,
7373 variable_names = model_names ,
7474 num_col = num_col ,
7575 num_row = num_row ,
7676 figsize = figsize ,
77+ default_name = "M" ,
7778 )
7879
7980 # Compute calibration
8081 cal_errors , true_probs , pred_probs = expected_calibration_error (
81- plot_data ["ground_truths " ], plot_data ["estimates " ], num_bins
82+ plot_data ["references " ], plot_data ["targets " ], num_bins
8283 )
8384
8485 for j , ax in enumerate (plot_data ["axes" ].flat ):
8586 # Plot calibration curve
86- ax [ j ] .plot (pred_probs [j ], true_probs [j ], "o-" , color = color )
87+ ax .plot (pred_probs [j ], true_probs [j ], "o-" , color = color )
8788
8889 # Plot PMP distribution over bins
8990 uniform_bins = np .linspace (0.0 , 1.0 , num_bins + 1 )
90- norm_weights = np .ones_like (plot_data ["estimates " ]) / len (plot_data ["estimates " ])
91- ax [ j ] .hist (plot_data ["estimates " ][:, j ], bins = uniform_bins , weights = norm_weights [:, j ], color = "grey" , alpha = 0.3 )
91+ norm_weights = np .ones_like (plot_data ["targets " ]) / len (plot_data ["targets " ])
92+ ax .hist (plot_data ["targets " ][:, j ], bins = uniform_bins , weights = norm_weights [:, j ], color = "grey" , alpha = 0.3 )
9293
9394 # Plot AB line
94- ax [ j ] .plot ((0 , 1 ), (0 , 1 ), "--" , color = "black" , alpha = 0.9 )
95+ ax .plot ((0 , 1 ), (0 , 1 ), "--" , color = "black" , alpha = 0.9 )
9596
9697 # Tweak plot
97- ax [ j ] .set_xlim ([0 - epsilon , 1 + epsilon ])
98- ax [ j ] .set_ylim ([0 - epsilon , 1 + epsilon ])
99- ax [ j ] .set_xticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
100- ax [ j ] .set_yticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
98+ ax .set_xlim ([0 - epsilon , 1 + epsilon ])
99+ ax .set_ylim ([0 - epsilon , 1 + epsilon ])
100+ ax .set_xticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
101+ ax .set_yticks ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 , 1.0 ])
101102
102103 # Add ECE label
103104 add_metric (
104- ax [ j ] ,
105- metric_text = r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f} " ,
105+ ax ,
106+ metric_text = r"$\widehat{{\mathrm{{ECE}}}}$" ,
106107 metric_value = cal_errors [j ],
107108 metric_fontsize = metric_fontsize ,
108109 )
0 commit comments