@@ -1663,7 +1663,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
 
 #if defined(GGML_DEBUG_SAMPLER_INFILL)
     for (size_t i = 0; i < cur_p->size; ++i) {
-        LLAMA_LOG_DEBUG("infill: cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+        LLAMA_LOG_DEBUG("infill: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
 #endif
 
@@ -1673,14 +1673,16 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
 
     for (size_t i = 0; i < cur_p->size; ++i) {
         p_max = fmaxf(p_max, cur_p->data[i].p);
+
         if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
             p_eog_sum += cur_p->data[i].p;
         } else {
             p_txt_sum += cur_p->data[i].p;
         }
     }
 
-    const float rat = p_txt_sum / p_eog_sum;
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum;
+
     LLAMA_LOG_DEBUG("infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_max, p_txt_sum, p_eog_sum, rat, cur_p->size);
 
     if (p_max < 0.90f && p_eog_sum*cur_p->size > p_txt_sum) {
@@ -1712,48 +1714,50 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
             }
 
             if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
-                if (cur_p->data[i].p > cur_p->data[j].p) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
                     cur_p->data[i].p += cur_p->data[j].p;
                     cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p     = 0.0f;
                 } else {
                     cur_p->data[j].p += cur_p->data[i].p;
                     cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p     = 0.0f;
                 }
             }
         }
     }
 
-    // mask non-EOG tokens with prob < 0.2
-    for (size_t i = 0; i < cur_p->size; ++i) {
+    const auto size_org = cur_p->size;
+
+    cur_p->size = 0;
+
+    float p_sum = 0.0f;
+
+    for (size_t i = 0; i < size_org; ++i) {
+        // discard non-EOG tokens with prob < 0.2
         if (cur_p->data[i].p < 0.2 && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
-            cur_p->data[i].logit = -INFINITY;
+            continue;
         }
-    }
 
-    // determine the token with max logit
-    float l_max = -INFINITY;
-    int i_max = -1;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        if (cur_p->data[i].logit > l_max) {
-            l_max = cur_p->data[i].logit;
-            i_max = i;
-        }
+        // keep this token
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
     }
 
     // if all probs are -INFINITY -> reduce cur_p to single EOG token
-    if (i_max == -1) {
+    if (cur_p->size == 0) {
        cur_p->size = 1;
        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
        cur_p->data[0].logit = 1.0f;
 
        return;
     }
 
-    // pick the best token
-    cur_p->size = 1;
-    cur_p->data[0] = cur_p->data[i_max];
-
+    // normalize probs
     for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
        LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
 }
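
For reference, a minimal standalone sketch of the keep-and-renormalize step this hunk introduces (the struct and function names below are illustrative, not part of the commit); the in-tree code additionally falls back to a single EOG token when nothing survives the filter:

#include <cstddef>
#include <vector>

// illustrative stand-in for one candidate token entry
struct candidate {
    int   id;
    float p;
    bool  is_eog; // would come from llama_token_is_eog_impl(vocab, id)
};

// keep EOG tokens and any token with p >= 0.2, then renormalize p over the kept set
static void keep_and_renormalize(std::vector<candidate> & cands) {
    size_t n_keep = 0;
    float  p_sum  = 0.0f;

    for (size_t i = 0; i < cands.size(); ++i) {
        // discard non-EOG tokens with prob < 0.2 (mirrors the new loop above)
        if (cands[i].p < 0.2f && !cands[i].is_eog) {
            continue;
        }

        // keep this token
        p_sum += cands[i].p;
        cands[n_keep++] = cands[i];
    }

    cands.resize(n_keep);

    // normalize probs over the surviving tokens
    for (size_t i = 0; i < cands.size(); ++i) {
        cands[i].p /= p_sum;
    }
}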