@@ -1648,9 +1648,6 @@ struct llama_sampler * llama_sampler_init_logit_bias(
16481648
16491649struct llama_sampler_infill {
16501650 const struct llama_vocab * vocab;
1651-
1652- const float p;
1653- const float p_eog;
16541651};
16551652
16561653static const char * llama_sampler_infill_name (const struct llama_sampler * /* smpl*/ ) {
@@ -1668,17 +1665,23 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
16681665 }
16691666
16701667 float p_max = 0 .0f ;
1668+ float p_txt_sum = 0 .0f ;
16711669 float p_eog_sum = 0 .0f ;
16721670
16731671 for (size_t i = 0 ; i < cur_p->size ; ++i) {
16741672 p_max = fmaxf (p_max, cur_p->data [i].p );
16751673 if (llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
16761674 p_eog_sum += cur_p->data [i].p ;
1675+ } else {
1676+ p_txt_sum += cur_p->data [i].p ;
16771677 }
16781678 }
16791679
1680- if (p_max < 0 .90f && p_eog_sum > ctx->p_eog ) {
1681- LLAMA_LOG_DEBUG (" infill: all EOG tokens are more likely than p_eog (%f), keeping only EOG tokens\n " , ctx->p_eog );
1680+ const float rat = p_txt_sum / p_eog_sum;
1681+ LLAMA_LOG_DEBUG (" infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n " , p_max, p_txt_sum, p_eog_sum, rat, cur_p->size );
1682+
1683+ if (p_max < 0 .90f && p_eog_sum*cur_p->size > p_txt_sum) {
1684+ LLAMA_LOG_DEBUG (" infill: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n " , p_txt_sum/p_eog_sum);
16821685
16831686 // keep just the EOG tokens
16841687 const auto size_org = cur_p->size ;
@@ -1717,9 +1720,9 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
17171720 }
17181721 }
17191722
1720- // mask non-EOG tokens with prob < ctx->p
1723+ // mask non-EOG tokens with prob < 0.2
17211724 for (size_t i = 0 ; i < cur_p->size ; ++i) {
1722- if (cur_p->data [i].p < ctx-> p && !llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
1725+ if (cur_p->data [i].p < 0.2 && !llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
17231726 cur_p->data [i].logit = -INFINITY;
17241727 }
17251728 }
@@ -1753,7 +1756,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
17531756
17541757static struct llama_sampler * llama_sampler_infill_clone (const struct llama_sampler * smpl) {
17551758 const auto * ctx = (const llama_sampler_infill *) smpl->ctx ;
1756- return llama_sampler_init_infill_impl (*ctx->vocab , ctx-> p , ctx-> p_eog );
1759+ return llama_sampler_init_infill_impl (*ctx->vocab );
17571760}
17581761
17591762static void llama_sampler_infill_free (struct llama_sampler * smpl) {
@@ -1770,15 +1773,11 @@ static struct llama_sampler_i llama_sampler_infill_i = {
17701773};
17711774
17721775struct llama_sampler * llama_sampler_init_infill_impl (
1773- const struct llama_vocab & vocab,
1774- float p,
1775- float p_eog) {
1776+ const struct llama_vocab & vocab) {
17761777 return new llama_sampler {
17771778 /* .iface = */ &llama_sampler_infill_i,
17781779 /* .ctx = */ new llama_sampler_infill {
17791780 /* .vocab = */ &vocab,
1780- /* .p = */ p,
1781- /* .p_eog = */ p_eog,
17821781 },
17831782 };
17841783}
0 commit comments