Skip to content

Commit 436967d

Browse files
committed
llama : improve infill sampler
ggml-ci
1 parent (cefb32b) · commit 436967d

File tree

2 files changed

+65
-16
lines changed

2 files changed

+65
-16
lines changed

include/llama.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1152,7 +1152,7 @@ extern "C" {
11521152
const llama_logit_bias * logit_bias);
11531153

11541154
// this sampler is meant to be used for fill-in-the-middle infilling
1155-
// it's supposed to be used after top_k sampling
1155+
// it's supposed to be used after top_k + top_p sampling
11561156
//
11571157
// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
11581158
// 2. combine probs of tokens that have the same prefix
@@ -1169,7 +1169,7 @@ extern "C" {
11691169
// "hel": 0.8
11701170
// "dummy": 0.1
11711171
//
1172-
// 3. discard non-EOG tokens with low prob (< 0.2)
1172+
// 3. discard non-EOG tokens with low prob
11731173
// 4. if no tokens are left -> pick EOT
11741174
//
11751175
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);

src/llama-sampling.cpp

Lines changed: 63 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1646,7 +1646,7 @@ struct llama_sampler * llama_sampler_init_logit_bias(
16461646

16471647
// infill
16481648

1649-
//#define GGML_DEBUG_SAMPLER_INFILL
1649+
#define GGML_DEBUG_SAMPLER_INFILL
16501650

16511651
struct llama_sampler_infill {
16521652
const struct llama_vocab * vocab;
@@ -1662,10 +1662,14 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
16621662
llama_sampler_softmax_impl(cur_p);
16631663

16641664
#if defined(GGML_DEBUG_SAMPLER_INFILL)
1665+
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
1666+
#else
1667+
#define LOG_DBG_CUR(...)
1668+
#endif
1669+
16651670
for (size_t i = 0; i < cur_p->size; ++i) {
1666-
LLAMA_LOG_DEBUG("infill: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
1671+
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
16671672
}
1668-
#endif
16691673

16701674
float p_txt_sum = 0.0f;
16711675
float p_eog_sum = 0.0f;
@@ -1680,10 +1684,10 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
16801684

16811685
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum;
16821686

1683-
LLAMA_LOG_DEBUG("infill: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_txt_sum, p_eog_sum, rat, cur_p->size);
1687+
LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
16841688

1685-
if (p_eog_sum*cur_p->size > p_txt_sum) {
1686-
LLAMA_LOG_DEBUG("infill: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", p_txt_sum/p_eog_sum);
1689+
if (3*p_eog_sum*cur_p->size > p_txt_sum) {
1690+
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
16871691

16881692
// keep just the EOG tokens
16891693
const auto size_org = cur_p->size;
@@ -1708,6 +1712,8 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
17081712
return;
17091713
}
17101714

1715+
size_t n_combined = 0;
1716+
17111717
// combine tokens with common prefix
17121718
for (size_t i = 0; i < cur_p->size; ++i) {
17131719
for (size_t j = 0; j < cur_p->size; ++j) {
@@ -1729,30 +1735,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
17291735
cur_p->data[i].logit = -INFINITY;
17301736
cur_p->data[i].p = 0.0f;
17311737
}
1738+
1739+
n_combined++;
17321740
}
17331741
}
17341742
}
17351743

1736-
const auto size_org = cur_p->size;
1744+
size_t n_non_eog = 0;
17371745

1738-
cur_p->size = 0;
1746+
size_t size_org = cur_p->size;
17391747

17401748
float p_sum = 0.0f;
1749+
float thold = 0.2f;
1750+
1751+
cur_p->size = 0;
1752+
1753+
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
17411754

17421755
for (size_t i = 0; i < size_org; ++i) {
1743-
// discard non-EOG tokens with prob < 0.2
1744-
if (cur_p->data[i].p < 0.2 && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
1756+
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
1757+
1758+
if (cur_p->data[i].p < thold && !is_eog) {
17451759
continue;
17461760
}
17471761

1748-
// keep this token
1762+
if (!is_eog) {
1763+
++n_non_eog;
1764+
}
1765+
17491766
p_sum += cur_p->data[i].p;
17501767

1768+
// keep this token
17511769
cur_p->data[cur_p->size++] = cur_p->data[i];
17521770
}
17531771

1754-
// if all probs are -INFINITY -> reduce cur_p to single EOG token
1755-
if (cur_p->size == 0) {
1772+
LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
1773+
1774+
// if no non-EOG tokens are left -> reduce cur_p to single EOT token
1775+
if (n_non_eog == 0) {
17561776
cur_p->size = 1;
17571777
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
17581778
cur_p->data[0].logit = 1.0f;
@@ -1764,8 +1784,37 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
17641784
for (size_t i = 0; i < cur_p->size; ++i) {
17651785
cur_p->data[i].p /= p_sum;
17661786

1767-
LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
1787+
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
17681788
}
1789+
1790+
size_org = cur_p->size;
1791+
p_sum = 0.0f;
1792+
thold = 1.0/(n_non_eog + 1);
1793+
1794+
cur_p->size = 0;
1795+
1796+
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
1797+
1798+
for (size_t i = 0; i < size_org; ++i) {
1799+
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
1800+
1801+
if (cur_p->data[i].p < thold && !is_eog) {
1802+
continue;
1803+
}
1804+
1805+
p_sum += cur_p->data[i].p;
1806+
1807+
cur_p->data[cur_p->size++] = cur_p->data[i];
1808+
}
1809+
1810+
// normalize probs
1811+
for (size_t i = 0; i < cur_p->size; ++i) {
1812+
cur_p->data[i].p /= p_sum;
1813+
1814+
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
1815+
}
1816+
1817+
#undef LOG_DBG_CUR
17691818
}
17701819

17711820
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {

0 commit comments

Comments (0)