@@ -851,7 +851,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
851851
852852 LOG_INF (" %s : calculating hellaswag score over selected tasks.\n " , __func__);
853853
854- LOG (" \n task\t acc_norm\n " );
854+ LOG (" \n task\t acc_norm\t 95%% confidence interval \ n" );
855855
856856 double acc = 0 .0f ;
857857
@@ -985,8 +985,22 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
985985 acc += 1.0 ;
986986 }
987987
988- // Print the accumulated accuracy mean x 100
989- LOG (" %zu\t %.8lf\n " , i + 1 , acc/double (i + 1 )*100.0 );
988+ double freq = acc / double (i + 1 );
989+
990+ const double za = 1.95996398454 ;
991+
992+ // // Wald normal approx
993+ // double conf =za*sqrt(freq*(1-freq)/double(i + 1));
994+ // LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);
995+
996+ // Wilson score interval, more accurate
997+ double z = za * za / double (i + 1 );
998+ double cnf = z * sqrt (double (i + 1 ) * (4.0 * freq * (1 - freq) + z)) / (za + za);
999+ double a = (freq + z * 0.5 - cnf) / (1.0 + z);
1000+ double b = (freq + z * 0.5 + cnf) / (1.0 + z);
1001+
1002+ // Print the accumulated accuracy mean x 100 and confidence interval
1003+ LOG (" %zu\t %3.8lf%%\t [%3.4lf%%, %3.4lf%%]\n " , i + 1 , freq * 100.0 , a * 100.0 , b * 100.0 );
9901004 }
9911005
9921006 i0 = i1 - 1 ;
0 commit comments