@@ -1667,13 +1667,10 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
16671667 }
16681668#endif
16691669
1670- float p_max = 0 .0f ;
16711670 float p_txt_sum = 0 .0f ;
16721671 float p_eog_sum = 0 .0f ;
16731672
16741673 for (size_t i = 0 ; i < cur_p->size ; ++i) {
1675- p_max = fmaxf (p_max, cur_p->data [i].p );
1676-
16771674 if (llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
16781675 p_eog_sum += cur_p->data [i].p ;
16791676 } else {
@@ -1683,22 +1680,31 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
16831680
16841681 const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum;
16851682
1686- LLAMA_LOG_DEBUG (" infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n " , p_max , p_txt_sum, p_eog_sum, rat, cur_p->size );
1683+ LLAMA_LOG_DEBUG (" infill: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n " , p_txt_sum, p_eog_sum, rat, cur_p->size );
16871684
1688- if (p_max < 0 . 90f && p_eog_sum*cur_p->size > p_txt_sum) {
1685+ if (p_eog_sum*cur_p->size > p_txt_sum) {
16891686 LLAMA_LOG_DEBUG (" infill: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n " , p_txt_sum/p_eog_sum);
16901687
16911688 // keep just the EOG tokens
16921689 const auto size_org = cur_p->size ;
16931690
16941691 cur_p->size = 0 ;
16951692
1693+ float p_sum = 0 .0f ;
1694+
16961695 for (size_t i = 0 ; i < size_org; ++i) {
16971696 if (llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
1697+ p_sum += cur_p->data [i].p ;
1698+
16981699 cur_p->data [cur_p->size ++] = cur_p->data [i];
16991700 }
17001701 }
17011702
1703+ // normalize probs
1704+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
1705+ cur_p->data [i].p /= p_sum;
1706+ }
1707+
17021708 return ;
17031709 }
17041710
0 commit comments