@@ -851,7 +851,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
851851
852852    LOG_INF (" %s : calculating hellaswag score over selected tasks.\n "  , __func__);
853853
854-     LOG (" \n task\t acc_norm\n "  );
854+     LOG (" \n task\t acc_norm\t  95%% confidence interval \ n"  );
855855
856856    double  acc = 0 .0f ;
857857
@@ -985,8 +985,22 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
985985                acc += 1.0 ;
986986            }
987987
988-             //  Print the accumulated accuracy mean x 100
989-             LOG (" %zu\t %.8lf\n "  , i + 1 , acc/double (i + 1 )*100.0 );
988+             double  freq = acc / double (i + 1 );
989+ 
990+             const  double  za = 1.95996398454 ;
991+ 
992+             //  // Wald normal approx
993+             //  double conf =za*sqrt(freq*(1-freq)/double(i + 1));
994+             //  LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);
995+ 
996+             //  Wilson score interval, more accurate
997+             double  z   = za * za / double (i + 1 );
998+             double  cnf = z * sqrt (double (i + 1 ) * (4.0  * freq * (1  - freq) + z)) / (za + za);
999+             double  a   = (freq + z * 0.5  - cnf) / (1.0  + z);
1000+             double  b   = (freq + z * 0.5  + cnf) / (1.0  + z);
1001+ 
1002+             //  Print the accumulated accuracy mean x 100 and confidence interval
1003+             LOG (" %zu\t %3.8lf%%\t [%3.4lf%%, %3.4lf%%]\n "  , i + 1 , freq * 100.0 , a * 100.0 , b * 100.0 );
9901004        }
9911005
9921006        i0 = i1 - 1 ;
0 commit comments