@@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
17391739    };
17401740}
17411741
1742+ //  infill
1743+ 
1744+ // #define GGML_DEBUG_SAMPLER_INFILL
1745+ 
// per-instance state of the infill sampler
struct llama_sampler_infill {
    // vocabulary consulted for the EOG / EOT / prefix token queries
    // held as a plain pointer: the sampler's free callback deletes only this struct, not the vocab
    const struct llama_vocab * vocab;
};
1749+ 
// name callback: returns the constant identifier of this sampler
// the sampler argument is not inspected
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
    return "infill";
}
1753+ 
1754+ static  void  llama_sampler_infill_apply (struct  llama_sampler  * smpl, llama_token_data_array * cur_p) {
1755+     auto  * ctx = (llama_sampler_infill *) smpl->ctx ;
1756+ 
1757+     llama_sampler_softmax_impl (cur_p);
1758+ 
1759+ #if  defined(GGML_DEBUG_SAMPLER_INFILL)
1760+ #define  LOG_DBG_CUR  LLAMA_LOG_DEBUG
1761+ #else 
1762+ #define  LOG_DBG_CUR (...)
1763+ #endif 
1764+ 
1765+     for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
1766+         LOG_DBG_CUR (" %s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n "  , __func__, i, cur_p->data [i].id , cur_p->data [i].p , cur_p->data [i].logit );
1767+     }
1768+ 
1769+     float  p_txt_sum = 0 .0f ;
1770+     float  p_eog_sum = 0 .0f ;
1771+ 
1772+     for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
1773+         if  (llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
1774+             p_eog_sum += cur_p->data [i].p ;
1775+         } else  {
1776+             p_txt_sum += cur_p->data [i].p ;
1777+         }
1778+     }
1779+ 
1780+     const  float  rat = p_eog_sum == 0.0  ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED (rat);
1781+ 
1782+     LOG_DBG_CUR (" %s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n "  , __func__, p_txt_sum, p_eog_sum, rat, cur_p->size );
1783+ 
1784+     if  (3 *p_eog_sum*cur_p->size  > p_txt_sum) {
1785+         LOG_DBG_CUR (" %s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n "  , __func__, p_txt_sum/p_eog_sum);
1786+ 
1787+         //  keep just the EOG tokens
1788+         const  auto  size_org = cur_p->size ;
1789+ 
1790+         cur_p->size  = 0 ;
1791+ 
1792+         float  p_sum = 0 .0f ;
1793+ 
1794+         for  (size_t  i = 0 ; i < size_org; ++i) {
1795+             if  (llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id )) {
1796+                 p_sum += cur_p->data [i].p ;
1797+ 
1798+                 cur_p->data [cur_p->size ++] = cur_p->data [i];
1799+             }
1800+         }
1801+ 
1802+         //  normalize probs
1803+         for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
1804+             cur_p->data [i].p  /= p_sum;
1805+         }
1806+ 
1807+         return ;
1808+     }
1809+ 
1810+     size_t  n_combined = 0 ; GGML_UNUSED (n_combined);
1811+ 
1812+     //  combine tokens with common prefix
1813+     for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
1814+         for  (size_t  j = 0 ; j < cur_p->size ; ++j) {
1815+             if  (cur_p->data [i].logit  == -INFINITY) {
1816+                 break ;
1817+             }
1818+ 
1819+             if  (i == j || cur_p->data [j].logit  == -INFINITY) {
1820+                 continue ;
1821+             }
1822+ 
1823+             if  (llama_token_is_prefix_impl (*ctx->vocab , cur_p->data [i].id , cur_p->data [j].id )) {
1824+                 if  (cur_p->data [i].p  >  cur_p->data [j].p ) {
1825+                     cur_p->data [i].p  += cur_p->data [j].p ;
1826+                     cur_p->data [j].logit  = -INFINITY;
1827+                     cur_p->data [j].p      = 0 .0f ;
1828+                 } else  {
1829+                     cur_p->data [j].p  += cur_p->data [i].p ;
1830+                     cur_p->data [i].logit  = -INFINITY;
1831+                     cur_p->data [i].p      = 0 .0f ;
1832+                 }
1833+ 
1834+                 n_combined++;
1835+             }
1836+         }
1837+     }
1838+ 
1839+     size_t  n_non_eog = 0 ;
1840+ 
1841+     size_t  size_org = cur_p->size ;
1842+ 
1843+     float  p_sum = 0 .0f ;
1844+     float  thold = 0 .2f ;
1845+ 
1846+     cur_p->size  = 0 ;
1847+ 
1848+     LOG_DBG_CUR (" %s: n_combined = %zu, applying thold = %.3f\n "  , __func__, n_combined, thold);
1849+ 
1850+     for  (size_t  i = 0 ; i < size_org; ++i) {
1851+         const  bool  is_eog = llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id );
1852+ 
1853+         if  (cur_p->data [i].p  < thold && !is_eog) {
1854+             continue ;
1855+         }
1856+ 
1857+         if  (!is_eog) {
1858+             ++n_non_eog;
1859+         }
1860+ 
1861+         p_sum += cur_p->data [i].p ;
1862+ 
1863+         //  keep this token
1864+         cur_p->data [cur_p->size ++] = cur_p->data [i];
1865+     }
1866+ 
1867+     LOG_DBG_CUR (" %s: n_non_eog = %zu\n "  , __func__, n_non_eog);
1868+ 
1869+     //  if no non-EOG tokens are left -> reduce cur_p to single EOT token
1870+     if  (n_non_eog == 0 ) {
1871+         cur_p->size  = 1 ;
1872+         cur_p->data [0 ].id  = llama_token_eot_impl (*ctx->vocab );
1873+         cur_p->data [0 ].logit  = 1 .0f ;
1874+ 
1875+         return ;
1876+     }
1877+ 
1878+     //  normalize probs
1879+     for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
1880+         cur_p->data [i].p  /= p_sum;
1881+ 
1882+         LOG_DBG_CUR (" %s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n "  , __func__, i, cur_p->data [i].id , cur_p->data [i].p , cur_p->data [i].logit );
1883+     }
1884+ 
1885+     size_org = cur_p->size ;
1886+     p_sum = 0 .0f ;
1887+     thold = 1.0 /(n_non_eog + 1 );
1888+ 
1889+     cur_p->size  = 0 ;
1890+ 
1891+     LOG_DBG_CUR (" %s: applying thold = %.3f\n "  , __func__, thold);
1892+ 
1893+     for  (size_t  i = 0 ; i < size_org; ++i) {
1894+         const  bool  is_eog = llama_token_is_eog_impl (*ctx->vocab , cur_p->data [i].id );
1895+ 
1896+         if  (cur_p->data [i].p  < thold && !is_eog) {
1897+             continue ;
1898+         }
1899+ 
1900+         p_sum += cur_p->data [i].p ;
1901+ 
1902+         cur_p->data [cur_p->size ++] = cur_p->data [i];
1903+     }
1904+ 
1905+     //  normalize probs
1906+     for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
1907+         cur_p->data [i].p  /= p_sum;
1908+ 
1909+         LOG_DBG_CUR (" %s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n "  , __func__, i, cur_p->data [i].id , cur_p->data [i].p , cur_p->data [i].logit );
1910+     }
1911+ 
1912+ #undef  LOG_DBG_CUR
1913+ }
1914+ 
1915+ static  struct  llama_sampler  * llama_sampler_infill_clone (const  struct  llama_sampler  * smpl) {
1916+     const  auto  * ctx = (const  llama_sampler_infill *) smpl->ctx ;
1917+     return  llama_sampler_init_infill_impl (*ctx->vocab );
1918+ }
1919+ 
1920+ static  void  llama_sampler_infill_free (struct  llama_sampler  * smpl) {
1921+     delete  (llama_sampler_infill *) smpl->ctx ;
1922+ }
1923+ 
1924+ static  struct  llama_sampler_i  llama_sampler_infill_i = {
1925+     /*  .name   = */   llama_sampler_infill_name,
1926+     /*  .accept = */   nullptr ,
1927+     /*  .apply  = */   llama_sampler_infill_apply,
1928+     /*  .reset  = */   nullptr ,
1929+     /*  .clone  = */   llama_sampler_infill_clone,
1930+     /*  .free   = */   llama_sampler_infill_free,
1931+ };
1932+ 
1933+ struct  llama_sampler  * llama_sampler_init_infill_impl (
1934+         const  struct  llama_vocab  & vocab) {
1935+     return  new  llama_sampler {
1936+         /*  .iface = */   &llama_sampler_infill_i,
1937+         /*  .ctx   = */   new  llama_sampler_infill {
1938+             /*  .vocab = */   &vocab,
1939+         },
1940+     };
1941+ }
1942+ 
17421943//  utils
17431944
17441945uint32_t  llama_sampler_get_seed (const  struct  llama_sampler  * smpl) {
0 commit comments