
Commit 4625fef

logit_bias: apply configurable escalating EOG bias at low n_remain
Give EOG tokens an increasing bias (per generated token; could be per codepoint in the future), but only after a configured amount has been generated.

Add an `n_remain` parameter to `sample_apply`, which is safer than having logit_bias keep state about how many times it has been called (that would lead to wrong assumptions, e.g. when it is called multiple times per token).

New command line options (including an 'after' option that was requested as an alternative to 'remain'):

-eog, --eog-bias-per-tok N
    when fewer than --start-eog-at-remain tokens are left to generate (of the -n max), add this bias to EOG for each subsequent token (default: 0.0)
-remain, --start-eog-at-remain N
    start applying the -eog bias when this many tokens remain of the -n max (default: 0.0)
-after, --start-eog-after N
    start applying the -eog bias after this many tokens have been generated (default: 1000000000.0); whichever happens first between -remain and -after applies

Verified that the EOG bias is effective at avoiding overgeneration and is a reasonable supplement or alternative to editing the prompt; a *constant* EOG bias, which the samplers already support, is likely to allow pathologically short outputs.
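As a rough illustration of the intended behavior (a minimal sketch, not the sampler code in this commit; the helper name `escalating_eog_bias` is made up), the per-token bias implied by these options looks like:

    // Sketch: how the escalating EOG bias grows as n_remain counts down.
    // start_eog_at_remain is the effective threshold (after being raised to
    // n_predict - start_eog_after, see common/arg.cpp below).
    static float escalating_eog_bias(float n_remain, float start_eog_at_remain, float eog_bias_per_tok) {
        const float tokens_past_threshold = start_eog_at_remain - n_remain;
        return tokens_past_threshold > 0.0f ? tokens_past_threshold * eog_bias_per_tok : 0.0f;
    }

    // Example: -n 512 -remain 64 -eog 0.25
    //   n_remain = 64 -> +0.0, n_remain = 63 -> +0.25, ..., n_remain = 1 -> +15.75 added to EOG logits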
1 parent 9de9672 commit 4625fef

19 files changed (+174, −72 lines)

common/arg.cpp

Lines changed: 30 additions & 0 deletions
@@ -1656,6 +1656,15 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
         exit(1); // for other exceptions, we exit with status code 1
     }
 
+    float &pafter = params.sampling.start_eog_after;
+    float &premain = params.sampling.start_eog_at_remain;
+    float const premain0 = premain;
+    float remain = params.n_predict - pafter;
+    if (premain < remain)
+        premain = remain;
+    if (params.sampling.eog_bias_per_tok)
+        LOG_INF("%s: n_predict=%d (first of start_eog_at_remain=%0.3g start_eog_after=%0.3g) => (remain=%0.3g) eog-bias-per-tok=%0.3g\n", __func__, (int) params.n_predict,
+                (double) premain0, (double) pafter, (double) premain, (double) params.sampling.eog_bias_per_tok);
     return true;
 }
 
@@ -2439,6 +2448,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-eog", "--eog-bias-per-tok"}, "N",
+        string_format("when fewer than --start-eog-at-remain tokens are left to generate (of the -n max), add this bias to EOG for each subsequent token (default: %.1f)", (double)params.sampling.eog_bias_per_tok),
+        [](common_params & params, const std::string & value) {
+            params.sampling.eog_bias_per_tok = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"-remain", "--start-eog-at-remain"}, "N",
+        string_format("start applying the -eog bias when this many tokens remain of the -n max (default: %.1f)", (double)params.sampling.start_eog_at_remain),
+        [](common_params & params, const std::string & value) {
+            params.sampling.start_eog_at_remain = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(common_arg(
+        {"-after", "--start-eog-after"}, "N",
+        string_format("start applying the -eog bias after this many tokens have been generated (default: %.1f); whichever happens first between -remain and -after applies", (double)params.sampling.start_eog_after),
+        [](common_params & params, const std::string & value) {
+            params.sampling.start_eog_after = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--grammar"}, "GRAMMAR",
         string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),

common/common.h

Lines changed: 7 additions & 0 deletions
@@ -188,6 +188,13 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
+    float eog_bias_per_tok    = 0;   // escalating bias added to EOG per generated token, starting once fewer than ...
+    // ... this many tokens remain of the -n max ...
+    float start_eog_at_remain = 0;
+    // ... or (whichever comes first) once start_eog_after tokens have been generated
+    // (i.e. EOG logit bias = max(0, start_eog_at_remain - n_remain) * eog_bias_per_tok, with start_eog_at_remain raised to at least n_predict - start_eog_after)
+    float start_eog_after = 1e9;
+
     // print the parameters into a string
     std::string print() const;
 };

common/sampling.cpp

Lines changed: 15 additions & 12 deletions
@@ -226,7 +226,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             llama_sampler_init_logit_bias(
                 llama_vocab_n_tokens(vocab),
                 params.logit_bias.size(),
-                params.logit_bias.data()));
+                params.logit_bias.data(),
+                params.eog_bias_per_tok,
+                params.start_eog_at_remain,
+                vocab));
 
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
@@ -336,18 +339,18 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
     }
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first, float n_remain) {
     gsmpl->set_logits(ctx, idx);
 
     auto & grmr = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
     if (grammar_first) {
-        llama_sampler_apply(grmr, &cur_p);
+        llama_sampler_apply(grmr, &cur_p, n_remain);
     }
 
-    llama_sampler_apply(chain, &cur_p);
+    llama_sampler_apply(chain, &cur_p, n_remain);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
@@ -362,7 +365,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
         llama_token_data single_token_data = { id, 1.0f, 0.0f };
         llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
 
-        llama_sampler_apply(grmr, &single_token_data_array);
+        llama_sampler_apply(grmr, &single_token_data_array, n_remain);
 
         const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
         if (is_valid) {
@@ -374,23 +377,23 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
     gsmpl->set_logits(ctx, idx);
 
-    llama_sampler_apply(grmr, &cur_p);
-    llama_sampler_apply(chain, &cur_p);
+    llama_sampler_apply(grmr, &cur_p, n_remain);
+    llama_sampler_apply(chain, &cur_p, n_remain);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
 
     return cur_p.data[cur_p.selected].id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first, float n_remain) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
     result.reserve(idxs.size());
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -402,7 +405,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }
 
     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -412,13 +415,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first, float n_remain) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first, n_remain);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
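The sampler-side change that consumes the extra constructor arguments (in src/llama-sampling.cpp, one of the 19 changed files) is not part of this excerpt. A minimal sketch of what applying the escalating bias could look like, assuming a hypothetical helper apply_escalating_eog_bias and the formula described in common/common.h above:

    // Hypothetical apply step (illustration only): raise the logit of every EOG
    // candidate by an amount that grows as n_remain counts down past the threshold.
    static void apply_escalating_eog_bias(llama_token_data_array * cur_p, const llama_vocab * vocab,
                                          float n_remain, float start_eog_at_remain, float eog_bias_per_tok) {
        const float tokens_past_threshold = start_eog_at_remain - n_remain;
        if (eog_bias_per_tok == 0.0f || tokens_past_threshold <= 0.0f) {
            return; // feature disabled or threshold not yet reached
        }
        const float bias = tokens_past_threshold * eog_bias_per_tok;
        for (size_t i = 0; i < cur_p->size; ++i) {
            if (llama_vocab_is_eog(vocab, cur_p->data[i].id)) {
                cur_p->data[i].logit += bias;
            }
        }
    }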

common/sampling.h

Lines changed: 3 additions & 3 deletions
@@ -58,7 +58,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 // if grammar_first is true, the grammar is applied before the samplers (slower)
 // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
 //
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false, float n_remain = 0);
 
 // generalized version of common_sampler_sample
 //
@@ -76,10 +76,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);
 
 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
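Because the new parameter has a default of 0, existing call sites still compile unchanged; callers that want the escalating bias thread a countdown through each sampling call. A sketch of the pattern the call-site updates below follow (assuming smpl, ctx, vocab and params are already set up; idx -1 selects the last output):

    int n_remain = params.n_predict;
    while (n_remain > 0) {
        const llama_token id = common_sampler_sample(smpl, ctx, /*idx=*/-1, /*grammar_first=*/false, (float) --n_remain);
        common_sampler_accept(smpl, id, /*accept_grammar=*/true);
        if (llama_vocab_is_eog(vocab, id)) {
            break; // the escalating bias makes this increasingly likely near the budget
        }
        // ... decode the accepted token and continue ...
    }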
common/speculative.cpp

Lines changed: 2 additions & 2 deletions
@@ -310,12 +310,12 @@ llama_tokens common_speculative_gen_draft(
     llama_decode(ctx_dft, batch);
 
     common_sampler_reset(smpl);
-
+    int n_remain = params.n_draft;
     // sample n_draft tokens from the draft model
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);
 
-        common_sampler_sample(smpl, ctx_dft, 0, true);
+        common_sampler_sample(smpl, ctx_dft, 0, true, --n_remain);
 
         const auto * cur_p = common_sampler_get_candidates(smpl, true);
 
examples/batched/batched.cpp

Lines changed: 3 additions & 1 deletion
@@ -162,7 +162,9 @@ int main(int argc, char ** argv) {
 
     const auto t_main_start = ggml_time_us();
 
+    int n_remain = n_predict;
     while (n_cur <= n_predict) {
+        --n_remain;
         // prepare the next batch
         common_batch_clear(batch);
 
@@ -173,7 +175,7 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i], n_remain);
 
             // is it an end of generation? -> mark the stream as finished
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {

examples/diffusion/diffusion-cli.cpp

Lines changed: 6 additions & 3 deletions
@@ -408,7 +408,8 @@ static void diffusion_generate(llama_context * ctx,
                 false,
             };
 
-            llama_sampler_apply(sampler, &cur_p);
+            float n_remain = params.max_length - pos;
+            llama_sampler_apply(sampler, &cur_p, n_remain);
             output_tokens[pos] = cur_p.data[cur_p.selected].id;
         }
     }
@@ -433,7 +434,8 @@ static void diffusion_generate(llama_context * ctx,
                 false,
             };
 
-            llama_sampler_apply(sampler, &cur_p);
+            float n_remain = params.max_length - i;
+            llama_sampler_apply(sampler, &cur_p, n_remain);
             llama_token sampled_token = cur_p.data[cur_p.selected].id;
 
             float conf = calculate_confidence(cur_p, params.algorithm, rng);
@@ -477,7 +479,8 @@ static void diffusion_generate(llama_context * ctx,
             };
 
             for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
-                llama_sampler_apply(dist_sampler, &conf_array);
+                float n_remain = params.max_length - i;
+                llama_sampler_apply(dist_sampler, &conf_array, n_remain);
                 int32_t selected_idx = conf_array.selected;
                 int32_t mask_idx = selected_idx;
                 int32_t pos = mask_positions[mask_idx];

examples/lookahead/lookahead.cpp

Lines changed: 5 additions & 2 deletions
@@ -253,6 +253,7 @@ int main(int argc, char ** argv) {
 
         int seq_id_best = 0;
 
+        int n_remain = N;
         for (int v = 0; v < N; ++v) {
             int i_batch = 0;
 
@@ -274,8 +275,9 @@ int main(int argc, char ** argv) {
                 }
             }
 
+            --n_remain;
             // sample the next token
-            id = common_sampler_sample(smpl, ctx, i_batch);
+            id = common_sampler_sample(smpl, ctx, i_batch, n_remain);
 
             common_sampler_accept(smpl, id, true);
 
@@ -349,10 +351,11 @@ int main(int argc, char ** argv) {
                 tokens_j[j] = tokens_j[j + 1];
             }
 
+            unsigned constexpr NA = (unsigned)-1;
             if (v == 0) {
                 // sample from the last level
                 for (int i = 0; i < W; i++) {
-                    tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                    tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i, NA);
                 }
             } else {
                 for (int i = 0; i < W; i++) {

examples/lookup/lookup.cpp

Lines changed: 2 additions & 1 deletion
@@ -117,7 +117,8 @@ int main(int argc, char ** argv){
         int i_dft = 0;
         while (true) {
             // sample from the target model
-            llama_token id = common_sampler_sample(smpl, ctx, i_dft);
+            unsigned const n_remain = params.n_predict - n_predict;
+            llama_token id = common_sampler_sample(smpl, ctx, i_dft, n_remain);
 
             common_sampler_accept(smpl, id, true);
 
examples/passkey/passkey.cpp

Lines changed: 3 additions & 1 deletion
@@ -217,10 +217,12 @@ int main(int argc, char ** argv) {
 
     const auto t_main_start = ggml_time_us();
 
+    int n_remain = n_len - n_cur;
     while (n_cur <= n_len) {
+        --n_remain;
         // sample the next token
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1, n_remain);
 
             // is it an end of generation?
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
