
Commit 4a9ceca

llama : simplify infill sampler
1 parent 141c5ce commit 4a9ceca

9 files changed, +27 -50 lines changed

common/arg.cpp

Lines changed: 0 additions & 14 deletions
@@ -947,20 +947,6 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.sparams.tfs_z = std::stof(value);
         }
     ).set_sparam());
-    add_opt(llama_arg(
-        {"--infill-p"}, "N",
-        string_format("infill p threshold (default: %.1f)", (double)params.sparams.infill_p),
-        [](gpt_params & params, const std::string & value) {
-            params.sparams.infill_p = std::stof(value);
-        }
-    ).set_sparam());
-    add_opt(llama_arg(
-        {"--infill-p-eog"}, "N",
-        string_format("infill p_eog threshold (default: %.1f)", (double)params.sparams.infill_p_eog),
-        [](gpt_params & params, const std::string & value) {
-            params.sparams.infill_p_eog = std::stof(value);
-        }
-    ).set_sparam());
     add_opt(llama_arg(
         {"--typical"}, "N",
         string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),

common/common.h

Lines changed: 0 additions & 2 deletions
@@ -114,8 +114,6 @@ struct gpt_sampler_params {
     float   temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float   dynatemp_range    = 0.00f; // 0.0 = disabled
     float   dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    float   infill_p          = 0.80f;
-    float   infill_p_eog      = 0.01f;
     int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float   penalty_repeat    = 1.00f; // 1.0 = disabled
     float   penalty_freq      = 0.00f; // 0.0 = disabled

common/sampling.cpp

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
             case GPT_SAMPLER_TYPE_INFILL:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model, params.infill_p, params.infill_p_eog));
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
                 break;
             default:
                 GGML_ASSERT(false && "unknown sampler type");

examples/llama.vim

Lines changed: 3 additions & 5 deletions
@@ -93,9 +93,7 @@ function! llama#fim(is_auto) abort
        "\ 'stop': g:llama_config.stop,
        \ 'n_predict': g:llama_config.n_predict,
        \ 'penalty_last_n': 0,
-       \ 'top_k': 5,
-       \ 'infill_p': 0.20,
-       \ 'infill_p_eog': 0.001,
+       \ 'top_k': 100,
        \ 'stream': v:false,
        \ 'samplers': ["top_k", "infill"],
        "\ 'cache_prompt': v:true,
@@ -180,15 +178,15 @@ function! s:fim_auto()
        call jobstop(s:current_job)
    endif

-   if reltimefloat(reltime(s:t_fim_last)) < 0.001*250
+   if reltimefloat(reltime(s:t_fim_last)) < 500*0.001
        if s:timer_fim != -1
            call timer_stop(s:timer_fim)
            let s:timer_fim = -1
        endif
    endif

    let s:t_fim_last = reltime()
-   let s:timer_fim = timer_start(250, {-> llama#fim(v:true)})
+   let s:timer_fim = timer_start(500, {-> llama#fim(v:true)})
 endfunction

examples/server/server.cpp

Lines changed: 0 additions & 4 deletions
@@ -894,8 +894,6 @@ struct server_context {
             slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
             slot.sparams.typ_p             = json_value(data, "typical_p",         default_sparams.typ_p);
             slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
-            slot.sparams.infill_p          = json_value(data, "infill_p",          default_sparams.infill_p);
-            slot.sparams.infill_p_eog      = json_value(data, "infill_p_eog",      default_sparams.infill_p_eog);
             slot.sparams.dynatemp_range    = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
             slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
             slot.sparams.penalty_last_n    = json_value(data, "repeat_last_n",     default_sparams.penalty_last_n);
@@ -1261,8 +1259,6 @@ struct server_context {
            {"min_p",            slot.sparams.min_p},
            {"tfs_z",            slot.sparams.tfs_z},
            {"typical_p",        slot.sparams.typ_p},
-           {"infill_p",         slot.sparams.infill_p},
-           {"infill_p_eog",     slot.sparams.infill_p_eog},
            {"repeat_last_n",    slot.sparams.penalty_last_n},
            {"repeat_penalty",   slot.sparams.penalty_repeat},
            {"presence_penalty", slot.sparams.penalty_present},

include/llama.h

Lines changed: 8 additions & 6 deletions
@@ -1150,8 +1150,11 @@ extern "C" {
                              int32_t   n_logit_bias,
              const llama_logit_bias * logit_bias);

-    // 1. if there is a high-prob token (>= 0.9f) - pick it
-    // 2. if sum of EOG probs is larger than p_eog -> mask non-EOG tokens away
+    // this sampler is meant to be used for fill-in-the-middle infilling
+    // it's supposed to be used after top_k sampling and will leave a single candidate token
+    //
+    // 1. if there is a high-prob token (>= 0.9f) -> pick it
+    // 2. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
     // 3. combine probs of tokens that have the same prefix
     //
     // example:
@@ -1166,10 +1169,9 @@ extern "C" {
     //   "hel":   0.8
     //   "dummy": 0.1
     //
-    LLAMA_API struct llama_sampler * llama_sampler_init_infill(
-            const struct llama_model * model,
-                               float   p,
-                               float   p_eog);
+    // 4. pick the token with the highest probability
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);

     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
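
With the simplified signature, a caller only passes the model. A minimal sketch of the intended usage, top_k feeding the infill sampler and a final distribution sampler, assuming an already loaded llama_model * model and llama_context * ctx (the top_k value of 40 is an arbitrary choice for illustration):

// Build a fill-in-the-middle sampler chain with the simplified API.
struct llama_sampler_chain_params cparams = llama_sampler_chain_default_params();
struct llama_sampler * chain = llama_sampler_chain_init(cparams);

llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));                 // narrow the candidate set first
llama_sampler_chain_add(chain, llama_sampler_init_infill(model));             // no p / p_eog arguments anymore
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));  // sample from what is left

const llama_token id = llama_sampler_sample(chain, ctx, -1); // sample for the last decoded token

llama_sampler_free(chain);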

src/llama-sampling.cpp

Lines changed: 12 additions & 13 deletions
@@ -1648,9 +1648,6 @@ struct llama_sampler * llama_sampler_init_logit_bias(

 struct llama_sampler_infill {
     const struct llama_vocab * vocab;
-
-    const float p;
-    const float p_eog;
 };

 static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
@@ -1668,17 +1665,23 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     }

     float p_max     = 0.0f;
+    float p_txt_sum = 0.0f;
     float p_eog_sum = 0.0f;

     for (size_t i = 0; i < cur_p->size; ++i) {
         p_max = fmaxf(p_max, cur_p->data[i].p);
         if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
             p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
         }
     }

-    if (p_max < 0.90f && p_eog_sum > ctx->p_eog) {
-        LLAMA_LOG_DEBUG("infill: all EOG tokens are more likely than p_eog (%f), keeping only EOG tokens\n", ctx->p_eog);
+    const float rat = p_txt_sum / p_eog_sum;
+    LLAMA_LOG_DEBUG("infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_max, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
+    if (p_max < 0.90f && p_eog_sum*cur_p->size > p_txt_sum) {
+        LLAMA_LOG_DEBUG("infill: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", p_txt_sum/p_eog_sum);

         // keep just the EOG tokens
         const auto size_org = cur_p->size;
@@ -1717,9 +1720,9 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
         }
     }

-    // mask non-EOG tokens with prob < ctx->p
+    // mask non-EOG tokens with prob < 0.2
     for (size_t i = 0; i < cur_p->size; ++i) {
-        if (cur_p->data[i].p < ctx->p && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+        if (cur_p->data[i].p < 0.2 && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
             cur_p->data[i].logit = -INFINITY;
         }
     }
@@ -1753,7 +1756,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_

 static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
-    return llama_sampler_init_infill_impl(*ctx->vocab, ctx->p, ctx->p_eog);
+    return llama_sampler_init_infill_impl(*ctx->vocab);
 }

 static void llama_sampler_infill_free(struct llama_sampler * smpl) {
@@ -1770,15 +1773,11 @@ static struct llama_sampler_i llama_sampler_infill_i = {
 };

 struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab,
-                           float   p,
-                           float   p_eog) {
+        const struct llama_vocab & vocab) {
     return new llama_sampler {
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
             /* .vocab = */ &vocab,
-            /* .p     = */ p,
-            /* .p_eog = */ p_eog,
         },
     };
 }
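
With the p / p_eog fields removed, the EOG decision above is now parameter-free: EOG candidates win when their combined probability, scaled by the number of candidates, outweighs the combined probability of the text tokens. A toy restatement of that rule follows; the helper name and its inputs are hypothetical, for illustration only.

// Hypothetical helper mirroring the new decision in llama_sampler_infill_apply:
// never pick EOG when a single token already dominates, otherwise pick EOG
// when p_eog_sum * n_candidates exceeds the total probability of the text tokens.
static bool infill_prefers_eog(float p_max, float p_txt_sum, float p_eog_sum, size_t n_candidates) {
    if (p_max >= 0.90f) {
        return false; // step 1: a high-prob token is always kept
    }
    return p_eog_sum*n_candidates > p_txt_sum; // replaces the old "p_eog_sum > p_eog" threshold
}

For example, with 10 candidates left after top_k and probabilities summing to 1, EOG is chosen once its combined probability exceeds roughly 1/11 ≈ 0.09, instead of the previous fixed p_eog default of 0.01.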

src/llama-sampling.h

Lines changed: 1 addition & 3 deletions
@@ -27,6 +27,4 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
            const char * grammar_root);

 struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab,
-                           float   p,
-                           float   p_eog);
+        const struct llama_vocab & vocab);

src/llama.cpp

Lines changed: 2 additions & 2 deletions
@@ -21817,8 +21817,8 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }

-struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model, float p, float p_eog) {
-    return llama_sampler_init_infill_impl(model->vocab, p, p_eog);
+struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
+    return llama_sampler_init_infill_impl(model->vocab);
 }

 //
