@@ -1745,6 +1745,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(
17451745
17461746struct llama_sampler_infill { // context object of the infill sampler (allocated in llama_sampler_init_infill_impl)
17471747 const struct llama_vocab * vocab; // vocabulary handed to llama_token_to_piece_impl when candidate tokens are compared
1748+
1749+ std::vector<char > buf0; // reusable scratch buffer that receives the text piece of the first token (i0) of a compared pair
1750+ std::vector<char > buf1; // reusable scratch buffer that receives the text piece of the second token (i1) of a compared pair
17481751};
17491752
17501753static const char * llama_sampler_infill_name (const struct llama_sampler * /* smpl*/ ) {
@@ -1810,27 +1813,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
18101813 size_t n_combined = 0 ; GGML_UNUSED (n_combined);
18111814
18121815 // combine tokens with common prefix
1813- for (size_t i = 0 ; i < cur_p->size ; ++i ) {
1814- for (size_t j = 0 ; j < cur_p->size ; ++j ) {
1815- if (cur_p->data [i ].logit == -INFINITY) {
1816+ for (size_t i0 = 0 ; i0 < cur_p->size ; ++i0 ) {
1817+ for (size_t i1 = 0 ; i1 < cur_p->size ; ++i1 ) {
1818+ if (cur_p->data [i0 ].logit == -INFINITY) {
18161819 break ;
18171820 }
18181821
1819- if (i == j || cur_p->data [j ].logit == -INFINITY) {
1822+ if (i0 == i1 || cur_p->data [i1 ].logit == -INFINITY) {
18201823 continue ;
18211824 }
18221825
1823- if (llama_token_is_prefix_impl (*ctx->vocab , cur_p->data [i].id , cur_p->data [j].id )) {
1824- if (cur_p->data [i].p > cur_p->data [j].p ) {
1825- cur_p->data [i].p += cur_p->data [j].p ;
1826- cur_p->data [j].logit = -INFINITY;
1827- cur_p->data [j].p = 0 .0f ;
1828- } else {
1829- cur_p->data [j].p += cur_p->data [i].p ;
1830- cur_p->data [i].logit = -INFINITY;
1831- cur_p->data [i].p = 0 .0f ;
1826+ int len0 = llama_token_to_piece_impl (*ctx->vocab , cur_p->data [i0].id , ctx->buf0 .data (), ctx->buf0 .size (), 0 , false ); // render token i0 into its scratch buffer
1827+ if (len0 < 0 ) { // negative result: buf0 was too small; by the llama_token_to_piece convention -len0 is the required size
1828+ ctx->buf0 .resize (-len0); // negate first: passing the raw negative value would convert to a huge size_t
1829+ len0 = llama_token_to_piece_impl (*ctx->vocab , cur_p->data [i0].id , ctx->buf0 .data (), ctx->buf0 .size (), 0 , false ); // retry with the enlarged buffer
1830+ assert (len0 > 0 );
1831+ }
1832+
1833+ int len1 = llama_token_to_piece_impl (*ctx->vocab , cur_p->data [i1].id , ctx->buf1 .data (), ctx->buf1 .size (), 0 , false ); // render token i1 into its scratch buffer
1834+ if (len1 < 0 ) { // same grow-and-retry handling as for buf0
1835+ ctx->buf1 .resize (-len1); // negate first: passing the raw negative value would convert to a huge size_t
1836+ len1 = llama_token_to_piece_impl (*ctx->vocab , cur_p->data [i1].id , ctx->buf1 .data (), ctx->buf1 .size (), 0 , false ); // retry with the enlarged buffer
1837+ assert (len1 > 0 );
1838+ }
1839+
1840+ // token i0 is a prefix of token i1 (only the first len0 bytes of each buffer are compared)
1841+ if (len0 > 0 && len0 <= len1 && memcmp (ctx->buf0 .data (), ctx->buf1 .data (), len0) == 0 ) {
1842+ size_t dst = i0; // indices stay size_t so they are not narrowed to int
1843+ size_t src = i1;
1844+
1845+ // merge into the token with higher probability
1846+ if (cur_p->data [i1].p > cur_p->data [i0].p ) {
1847+ std::swap (dst, src);
18321848 }
18331849
1850+ cur_p->data [dst].p += cur_p->data [src].p ; // the surviving token absorbs the probability mass
1851+ cur_p->data [src].logit = -INFINITY; // mark the absorbed token so later iterations skip it
1852+ cur_p->data [src].p = 0 .0f ;
1853+
18341854 n_combined++;
18351855 }
18361856 }
@@ -1936,6 +1956,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
19361956 /* .iface = */ &llama_sampler_infill_i,
19371957 /* .ctx = */ new llama_sampler_infill {
19381958 /* .vocab = */ &vocab,
1959+ /* .buf0 = */ std::vector<char >(512 ),
1960+ /* .buf1 = */ std::vector<char >(512 ),
19391961 },
19401962 };
19411963}
0 commit comments