@@ -1644,6 +1644,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     };
 }
 
+// infill
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
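+// Multi-stage filtering of the candidates for fill-in-the-middle (infill) completion: after
+// softmax, either collapse the candidates to end-of-generation (EOG) tokens when the remaining
+// text probability mass is low, or merge tokens that are prefixes of one another and apply two
+// probability thresholds before renormalizing.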
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
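+// LOG_DBG_CUR expands to LLAMA_LOG_DEBUG only when GGML_DEBUG_SAMPLER_INFILL is defined;
+// otherwise it compiles to nothing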
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
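+    // heuristic: if the text probability mass, averaged over all candidates, is below three
+    // times the total EOG probability, restrict the candidates to EOG tokens only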
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
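+    // (when one candidate token is a text prefix of another, fold the probability of the less
+    //  likely one into the more likely one and mask the former out with a -INFINITY logit)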
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p     = 0.0f;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p     = 0.0f;
+                }
+
+                n_combined++;
+            }
+        }
+    }
+
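+    // first pass: drop candidates whose probability is below a fixed threshold, always keeping
+    // EOG tokens, and count how many non-EOG tokens survive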
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
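+    // second pass: apply a dynamic threshold of 1/(n_non_eog + 1) to trim the remaining
+    // low-probability non-EOG candidates, then renormalize once more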
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+        },
+    };
+}
+
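+// Usage sketch (assumes a public wrapper, here called llama_sampler_init_infill, that forwards
+// the model vocab to llama_sampler_init_infill_impl): the infill sampler plugs into a regular
+// sampler chain like any other sampler:
+//
+//     llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
+//     llama_sampler_chain_add(chain, llama_sampler_init_infill(model));             // hypothetical wrapper
+//     llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));  // final pick
+//
+//     const llama_token id = llama_sampler_sample(chain, ctx, -1);
+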
 // utils
 
 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {