@@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
17571757 // find max logit and calculate mean
17581758 float max = cur_p->data [0 ].logit ;
17591759 float logits_sum = 0 ;
1760+ size_t valid_count = 0 ;
17601761 for (size_t i = 0 ; i < cur_p->size ; ++i) {
1761- if (cur_p->data [i].logit > max) {
1762- max = cur_p->data [i].logit ;
1762+ // Only count values that are not negative infinity
1763+ if (cur_p->data [i].logit != -INFINITY) {
1764+ if (cur_p->data [i].logit > max) {
1765+ max = cur_p->data [i].logit ;
1766+ }
1767+ logits_sum += cur_p->data [i].logit ;
1768+ valid_count++;
17631769 }
1764- logits_sum += cur_p->data [i].logit ;
17651770 }
1766- float mean = logits_sum/cur_p-> size ;
1771+ float mean = valid_count > 0 ? logits_sum/valid_count : 0 ;
17671772
17681773 // calculate standard deviation
17691774 float acc = 0 ;
17701775 for (size_t i = 0 ; i < cur_p->size ; ++i) {
1771- acc += pow (cur_p->data [i].logit - mean, 2 );
1776+ // Skip -infinity in std calculation
1777+ if (cur_p->data [i].logit != -INFINITY) {
1778+ acc += pow (cur_p->data [i].logit - mean, 2 );
1779+ }
17721780 }
1773- float std = sqrt (acc/cur_p-> size ) ;
1781+ float std = valid_count > 0 ? sqrt (acc/valid_count) : 0 ;
17741782
17751783 // apply mask
17761784 for (size_t i = 0 ; i < cur_p->size ; ++i) {
0 commit comments