Skip to content

Commit 4d84652

Browse files
committed
llama : simplify infill sampler
1 parent 59c0756 commit 4d84652

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

examples/llama.vim

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ let s:default_config = {
7070
\ 'n_suffix': 128,
7171
\ 'n_predict': 64,
7272
\ 't_max_prompt_ms': 500,
73-
\ 't_max_predict_ms': 200,
73+
\ 't_max_predict_ms': 500,
7474
\ 'show_info': 2,
7575
\ 'auto_fim': v:true,
7676
\ 'max_line_suffix': 8,
77-
\ 'ring_n_chunks': 32,
77+
\ 'ring_n_chunks': 16,
7878
\ 'ring_chunk_size': 128,
7979
\ 'ring_scope': 1024,
8080
\ }

include/llama.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,9 +1154,8 @@ extern "C" {
11541154
// this sampler is meant to be used for fill-in-the-middle infilling
11551155
// it's supposed to be used after top_k sampling
11561156
//
1157-
// 1. if there is a high-prob token (>= 0.9f) -> skip step 2
1158-
// 2. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
1159-
// 3. combine probs of tokens that have the same prefix
1157+
// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
1158+
// 2. combine probs of tokens that have the same prefix
11601159
//
11611160
// example:
11621161
//
@@ -1170,8 +1169,8 @@ extern "C" {
11701169
// "hel": 0.8
11711170
// "dummy": 0.1
11721171
//
1173-
// 4. discard non-EOG tokens with low prob (< 0.2)
1174-
// 5. if no tokens are left -> pick EOT
1172+
// 3. discard non-EOG tokens with low prob (< 0.2)
1173+
// 4. if no tokens are left -> pick EOT
11751174
//
11761175
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
11771176

src/llama-sampling.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,13 +1667,10 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
16671667
}
16681668
#endif
16691669

1670-
float p_max = 0.0f;
16711670
float p_txt_sum = 0.0f;
16721671
float p_eog_sum = 0.0f;
16731672

16741673
for (size_t i = 0; i < cur_p->size; ++i) {
1675-
p_max = fmaxf(p_max, cur_p->data[i].p);
1676-
16771674
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
16781675
p_eog_sum += cur_p->data[i].p;
16791676
} else {
@@ -1683,22 +1680,31 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
16831680

16841681
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum;
16851682

1686-
LLAMA_LOG_DEBUG("infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_max, p_txt_sum, p_eog_sum, rat, cur_p->size);
1683+
LLAMA_LOG_DEBUG("infill: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_txt_sum, p_eog_sum, rat, cur_p->size);
16871684

1688-
if (p_max < 0.90f && p_eog_sum*cur_p->size > p_txt_sum) {
1685+
if (p_eog_sum*cur_p->size > p_txt_sum) {
16891686
LLAMA_LOG_DEBUG("infill: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", p_txt_sum/p_eog_sum);
16901687

16911688
// keep just the EOG tokens
16921689
const auto size_org = cur_p->size;
16931690

16941691
cur_p->size = 0;
16951692

1693+
float p_sum = 0.0f;
1694+
16961695
for (size_t i = 0; i < size_org; ++i) {
16971696
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
1697+
p_sum += cur_p->data[i].p;
1698+
16981699
cur_p->data[cur_p->size++] = cur_p->data[i];
16991700
}
17001701
}
17011702

1703+
// normalize probs
1704+
for (size_t i = 0; i < cur_p->size; ++i) {
1705+
cur_p->data[i].p /= p_sum;
1706+
}
1707+
17021708
return;
17031709
}
17041710

0 commit comments

Comments (0)