@@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
    };
}

+// infill
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
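+    // compute normalized probabilities for the candidates; the heuristics below operate on data[i].p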
+    llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum/p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
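+    // heuristic: if the total EOG probability mass is significant relative to the text mass
+    // (scaled by the number of candidates), restrict sampling to EOG tokens and end the infill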
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
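+    // when one candidate's text is a prefix of another's, merge the pair by adding the probability
+    // of the less likely token to the more likely one and disabling the former (logit = -INFINITY)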
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p     = 0.0f;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p     = 0.0f;
+                }
+
+                n_combined++;
+            }
+        }
+    }
+
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
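+    // first pass: keep all EOG tokens and only those non-EOG tokens whose probability reaches
+    // the fixed threshold, counting how many non-EOG tokens survive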
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id    = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
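+    // second pass: re-filter with an adaptive threshold of 1/(n_non_eog + 1), so only the
+    // comparatively likely non-EOG tokens (plus all EOG tokens) remain before the final renormalization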
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+        },
+    };
+}
+
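+// usage sketch (not part of this change): a public wrapper in llama.h is assumed to expose this
+// sampler, e.g. llama_sampler_init_infill(model); it would typically be added to a sampler chain
+// and followed by a distribution sampler:
+//
+//   struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
+//   llama_sampler_chain_add(chain, llama_sampler_init_infill(model));
+//   llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+//   const llama_token id = llama_sampler_sample(chain, ctx, -1);
+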
// utils

uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {