@@ -1646,7 +1646,7 @@ struct llama_sampler * llama_sampler_init_logit_bias(
 
 // infill
 
-// #define GGML_DEBUG_SAMPLER_INFILL
+#define GGML_DEBUG_SAMPLER_INFILL
 
 struct llama_sampler_infill {
     const struct llama_vocab * vocab;
@@ -1662,10 +1662,14 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     llama_sampler_softmax_impl(cur_p);
 
 #if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
     for (size_t i = 0; i < cur_p->size; ++i) {
-        LLAMA_LOG_DEBUG("infill: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
-#endif
 
     float p_txt_sum = 0.0f;
     float p_eog_sum = 0.0f;
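
A side note on the LOG_DBG_CUR macro introduced above: when GGML_DEBUG_SAMPLER_INFILL is not defined, the variadic form #define LOG_DBG_CUR(...) expands every call to nothing, so the per-candidate logging has zero cost in normal builds. A minimal standalone sketch of the same pattern, assuming an illustrative DEBUG_SAMPLER flag and LOG_DBG name (not part of llama.cpp):

#include <cstdio>

// compile with -DDEBUG_SAMPLER to enable the logging
#if defined(DEBUG_SAMPLER)
#define LOG_DBG(...) fprintf(stderr, __VA_ARGS__)
#else
#define LOG_DBG(...)   // calls compile away to nothing
#endif

int main() {
    // prints only in debug builds; in release builds the call disappears entirely
    LOG_DBG("%s: cur_p[%3zu] = { id: %6d, p: %.6f }\n", __func__, (size_t) 0, 42, 0.125);
    return 0;
}
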
@@ -1680,10 +1684,10 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
 
     const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum/p_eog_sum;
 
-    LLAMA_LOG_DEBUG("infill: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_txt_sum, p_eog_sum, rat, cur_p->size);
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
 
-    if (p_eog_sum*cur_p->size > p_txt_sum) {
-        LLAMA_LOG_DEBUG("infill: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", p_txt_sum/p_eog_sum);
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
 
         // keep just the EOG tokens
         const auto size_org = cur_p->size;
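
To see what the changed condition does, take for example 16 candidates with p_eog_sum = 0.03 and p_txt_sum = 0.97: the old test 0.03*16 = 0.48 > 0.97 is false and text tokens are kept, while the new test 3*0.03*16 = 1.44 > 0.97 is true and the sampler falls back to EOG. In other words, the ratio p_txt/p_eog now has to exceed 3*n (instead of n) to keep sampling text, so the EOG fallback triggers three times more readily.
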
@@ -1708,6 +1712,8 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         return;
     }
 
+    size_t n_combined = 0;
+
     // combine tokens with common prefix
     for (size_t i = 0; i < cur_p->size; ++i) {
         for (size_t j = 0; j < cur_p->size; ++j) {
@@ -1729,30 +1735,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
                     cur_p->data[i].logit = -INFINITY;
                     cur_p->data[i].p     = 0.0f;
                 }
+
+                n_combined++;
             }
         }
     }
 
-    const auto size_org = cur_p->size;
+    size_t n_non_eog = 0;
 
-    cur_p->size = 0;
+    size_t size_org = cur_p->size;
 
     float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        // discard non-EOG tokens with prob < 0.2
-        if (cur_p->data[i].p < 0.2 && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
             continue;
         }
 
-        // keep this token
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
         p_sum += cur_p->data[i].p;
 
+        // keep this token
         cur_p->data[cur_p->size++] = cur_p->data[i];
     }
 
-    // if all probs are -INFINITY -> reduce cur_p to single EOG token
-    if (cur_p->size == 0) {
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
         cur_p->size = 1;
         cur_p->data[0].id    = llama_token_eot_impl(*ctx->vocab);
         cur_p->data[0].logit = 1.0f;
@@ -1764,8 +1784,37 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     for (size_t i = 0; i < cur_p->size; ++i) {
         cur_p->data[i].p /= p_sum;
 
-        LLAMA_LOG_DEBUG("after: cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
+
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
 }
 
 static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
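
Taken together, the additions turn the tail filtering into two passes: a first pass with a fixed threshold of 0.2 that also counts the surviving non-EOG tokens, and a second pass with an adaptive threshold of 1/(n_non_eog + 1), with the probabilities renormalized after each pass. The fewer non-EOG tokens survive the first pass, the stricter the second cut becomes. A rough standalone sketch of that flow on a toy candidate list, assuming illustrative Cand/filter_and_norm names rather than the actual llama.cpp types:

#include <cstdio>
#include <utility>
#include <vector>

struct Cand { int id; float p; bool is_eog; };   // stand-in for llama_token_data

// keep candidates with p >= thold (EOG tokens are always kept), then renormalize
static size_t filter_and_norm(std::vector<Cand> & cands, float thold) {
    size_t n_non_eog = 0;
    float  p_sum     = 0.0f;

    std::vector<Cand> kept;
    for (const auto & c : cands) {
        if (c.p < thold && !c.is_eog) {
            continue;
        }
        if (!c.is_eog) {
            ++n_non_eog;
        }
        p_sum += c.p;
        kept.push_back(c);
    }

    for (auto & c : kept) {
        c.p /= p_sum;                            // renormalize survivors
    }
    cands = std::move(kept);

    return n_non_eog;
}

int main() {
    std::vector<Cand> cands = {
        {10, 0.45f, false}, {11, 0.30f, false}, {12, 0.15f, false}, {2, 0.10f, true /*EOG*/},
    };

    // pass 1: fixed threshold, mirrors thold = 0.2f in the sampler
    const size_t n_non_eog = filter_and_norm(cands, 0.2f);

    // pass 2: adaptive threshold, mirrors thold = 1.0/(n_non_eog + 1)
    filter_and_norm(cands, 1.0f/(n_non_eog + 1));

    for (const auto & c : cands) {
        printf("id = %2d, p = %.3f, eog = %d\n", c.id, c.p, c.is_eog);
    }
    return 0;
}
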