
Commit b011eab

Added DRY sampling ggml-org/llama.cpp#6839
1 parent fa6fb88

7 files changed: +171 -16 lines changed


base/llama-addon.cpp

Lines changed: 90 additions & 0 deletions

@@ -380,4 +380,94 @@ void llama_sample_entropy_addon(struct llama_context * ctx, llama_token_data_arr
 #endif
 
     llama_set_time(ctx, t_start_sample_us);
 }
+
+void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
+    // skip the DRY sampler if we don't have a previous token
+    if (last_tokens_size < 1) return;
+
+    // get the last token
+    auto last_token = last_tokens[last_tokens_size - 1];
+
+    // if the last token is one of the sequence breakers, skip the whole sampler
+    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
+        return;
+    }
+
+    // create an unordered map of "next tokens" <-> max match length
+    std::unordered_map<llama_token, size_t> match_lengths;
+
+    // loop through each previous token (excluding the last token)
+    for (size_t i = 0; i < last_tokens_size - 1; ++i) {
+        // skip if the compared token is not the same as the last token
+        if (last_tokens[i] != last_token) {
+            continue;
+        }
+
+        // get the next token (i + 1 is always less than last_tokens_size)
+        auto next_token = last_tokens[i + 1];
+
+        // if the next token is one of the sequence breakers, skip
+        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
+            continue;
+        }
+
+        // try to extend the match backwards (match length starts at 1 because the last token is already matched)
+        size_t match_length = 1;
+
+        // loop through the previous tokens
+        for (;; match_length++) {
+            // if we have reached the start of our last tokens, break
+            if (i < match_length) break;
+
+            // the compared token starts at our previous index, going backwards by match length
+            auto compare_token = last_tokens[i - match_length];
+
+            // the head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
+            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+
+            // break out of the match if any tokens don't match
+            if (compare_token != head_token) {
+                break;
+            }
+
+            // if the compared token is one of the sequence breakers, break out of the match
+            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
+                break;
+            }
+        }
+
+        // check if the next token already exists in the map
+        auto it = match_lengths.find(next_token);
+
+        if (it == match_lengths.end()) {
+            // key does not exist, insert the new value
+            match_lengths[next_token] = match_length;
+        } else {
+            // key exists, update it with the max of the new value and the existing value
+            it->second = std::max(it->second, match_length);
+        }
+    }
+
+    // apply penalties
+    for (const auto& pair : match_lengths) {
+        auto next_token = pair.first;
+        auto match_length = pair.second;
+
+        // if the match length is greater than or equal to the allowed length from the config, apply penalties
+        if (match_length >= (size_t)dry_allowed_length) {
+
+            // find the next token in candidates->data
+            for (size_t i = 0; i < candidates->size; ++i) {
+                if (candidates->data[i].id == next_token) {
+                    // calculate the penalty
+                    float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);
+
+                    // apply the DRY penalty
+                    candidates->data[i].logit -= penalty;
+                    break;
+                }
+            }
+        }
+    }
+}
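Since the penalty is dry_multiplier * dry_base^(match_length - dry_allowed_length), it grows exponentially with the length of the repeated run. A minimal standalone sketch of the curve, plugging in the defaults declared in base/sampling.h below (illustrative only, not part of the commit):

// Standalone sketch of the DRY penalty curve:
//   penalty = dry_multiplier * dry_base^(match_length - dry_allowed_length)
#include <cmath>
#include <cstdio>

int main() {
    const float dry_multiplier     = 0.8f;  // recommended value per base/sampling.h
    const float dry_base           = 1.75f; // default in base/sampling.h
    const int   dry_allowed_length = 2;     // default in base/sampling.h

    for (int match_length = 2; match_length <= 6; ++match_length) {
        const float penalty = dry_multiplier * std::pow(dry_base, (float)(match_length - dry_allowed_length));
        printf("match_length = %d -> penalty = %.3f\n", match_length, penalty);
    }
    return 0;
}

This prints 0.800, 1.400, 2.450, 4.288 and 7.503 for match lengths 2 through 6, so verbatim repetition becomes expensive quickly.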

base/llama-addon.h

Lines changed: 11 additions & 0 deletions

@@ -39,3 +39,14 @@
     float temp,
     float smoothing_factor,
     float smoothing_curve);
+
+/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
+LLAMA_API void llama_sample_dry(
+    llama_token_data_array * candidates,
+    const llama_token * last_tokens,
+    size_t last_tokens_size,
+    float dry_base,
+    float dry_multiplier,
+    int dry_allowed_length,
+    const llama_token * dry_seq_breakers,
+    size_t dry_seq_breakers_size);
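A hypothetical call-site sketch for this declaration (not from the commit): the breaker token id 13 is made up, and the llama_token_data / llama_token_data_array layouts are assumed to be the llama.cpp ones of this era.

// Hypothetical usage sketch of llama_sample_dry (illustrative only).
// Assumes llama.h types: llama_token (int32_t), llama_token_data {id, logit, p},
// llama_token_data_array {data, size, sorted}.
#include <vector>
#include "llama.h"
#include "llama-addon.h"

void dry_example(const float * logits, int n_vocab, const std::vector<llama_token> & history) {
    // build the candidate array over the whole vocabulary
    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id) {
        cur.push_back({ id, logits[id], 0.0f });
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    // hypothetical sequence breaker: the model's newline token id
    std::vector<llama_token> breakers = { 13 };

    llama_sample_dry(&cur_p,
                     history.data(), history.size(),
                     /*dry_base*/ 1.75f, /*dry_multiplier*/ 0.8f, /*dry_allowed_length*/ 2,
                     breakers.data(), breakers.size());
}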

base/sampling.cpp

Lines changed: 19 additions & 0 deletions

@@ -189,14 +189,22 @@ llama_token llama_sampling_sample(
     const float smoothing_factor = params.smoothing_factor;
     const float smoothing_curve = params.smoothing_curve;
     const float dynatemp_range = params.dynatemp_range;
+    // repetition
     const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
     const float penalty_repeat = params.penalty_repeat;
     const float penalty_freq = params.penalty_freq;
     const float penalty_present = params.penalty_present;
     const float penalty_threshold = params.penalty_threshold;
+    // DRY
+    const float dry_multiplier = params.dry_multiplier;
+    const float dry_base = params.dry_base;
+    const uint32_t dry_allowed_length = params.dry_allowed_length;
+    const uint32_t dry_penalty_last_n = params.dry_penalty_last_n;
+    // mirostat
     const int mirostat = params.mirostat;
     const float mirostat_tau = params.mirostat_tau;
     const float mirostat_eta = params.mirostat_eta;
+
     const bool penalize_nl = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;

@@ -248,6 +256,17 @@ llama_token llama_sampling_sample(
         }
     }
 
+    // apply DRY penalties
+    {
+        const int penalty_tokens_used_size = std::min(prev.size(), (size_t)dry_penalty_last_n);
+        if (penalty_tokens_used_size) {
+            llama_sample_dry(&cur_p,
+                    prev.data() + prev.size() - penalty_tokens_used_size,
+                    penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
+                    params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
+        }
+    }
+
     if (ctx_sampling->grammar != NULL) {
         llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
     }
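The new block passes only the most recent dry_penalty_last_n tokens of prev to llama_sample_dry via pointer arithmetic. A small self-contained sketch of that windowing, with made-up values:

// Illustrative sketch of the last-n windowing used above.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> prev = { 11, 22, 33, 44, 55, 66 }; // stored history, oldest first
    const uint32_t dry_penalty_last_n = 4;

    const int used = std::min(prev.size(), (size_t)dry_penalty_last_n);
    const int * window = prev.data() + prev.size() - used; // points at the last `used` tokens

    for (int i = 0; i < used; ++i) printf("%d ", window[i]); // prints: 33 44 55 66
    printf("\n");
    return 0;
}

Note that the default dry_penalty_last_n = -1 from sampling.h lands in a uint32_t local above, so the cast to size_t yields a huge value and std::min simply selects prev.size(), i.e. the whole stored history.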

base/sampling.h

Lines changed: 5 additions & 0 deletions

@@ -53,6 +53,10 @@ typedef struct llama_sampling_params {
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
     bool penalize_nl = true; // consider newlines as a repeatable token
+    float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
+    float dry_base = 1.75f;
+    uint32_t dry_allowed_length = 2;
+    int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
     //std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
     std::string samplers_sequence = "kfysmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
 
@@ -64,6 +68,7 @@ typedef struct llama_sampling_params {
     float cfg_scale = 1.f; // how strong is guidance
 
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+    std::vector<llama_token> dry_seq_breakers; // sequence breakers for the DRY sampler
 } llama_sampling_params;
 
 // general sampler context
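With these defaults DRY stays off, since dry_multiplier = 0.0f makes every computed penalty zero. A minimal sketch of enabling it (assuming this repo's headers; the breaker token id 13 is hypothetical and model-specific):

// Sketch: enabling DRY through the new llama_sampling_params fields.
#include "sampling.h" // this repo's base/sampling.h

int main() {
    llama_sampling_params sparams;        // defaults as declared above; DRY disabled
    sparams.dry_multiplier     = 0.8f;    // recommended strength per the header comment
    sparams.dry_penalty_last_n = -1;      // -1 = penalize over the whole context
    sparams.dry_seq_breakers   = { 13 };  // hypothetical newline token id; breakers stop matches
    return 0;
}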

chat_plain.h

Lines changed: 19 additions & 11 deletions

@@ -505,6 +505,11 @@ class chat
     std::string name_penalty_threshold = fullnames ? "penalty_threshold" : "p_t";
     std::string name_penalty_freq = fullnames ? "penalty_freq" : "p_f";
     std::string name_penalty_present = fullnames ? "penalty_present" : "p_p";
+    // DRY
+    std::string name_dry_multiplier = fullnames ? "dry_multiplier" : "d_m";
+    std::string name_dry_base = fullnames ? "dry_base" : "d_b";
+    std::string name_dry_allowed_length = fullnames ? "dry_allowed_length" : "d_l";
+    std::string name_dry_penalty_last_n = fullnames ? "dry_penalty_last_n" : "d_n";
 
     std::string name_temp = fullnames ? "temp" : "T";
     std::string name_dynatemp_range = fullnames ? "dynatemp_range" : "dT";

@@ -520,12 +525,15 @@ class chat
     std::string name_top_p = fullnames ? "top_p" : "P";
     std::string name_min_p = fullnames ? "min_p" : "I";
 
-    if (params.sparams.penalty_repeat != paramsDefault.sparams.penalty_repeat) result += "->" + name_penalty_repeat + " = " + std::to_string(params.sparams.penalty_repeat);
-    if (params.sparams.penalty_threshold != paramsDefault.sparams.penalty_threshold) result += "->" + name_penalty_threshold + " = " + std::to_string(params.sparams.penalty_threshold);
-    if (params.sparams.penalty_freq != paramsDefault.sparams.penalty_freq) result += "->" + name_penalty_freq + " = " + std::to_string(params.sparams.penalty_freq);
-    if (params.sparams.penalty_present != paramsDefault.sparams.penalty_present) result += "->" + name_penalty_present + " = " + std::to_string(params.sparams.penalty_present);
-
-
+    if (params.sparams.penalty_repeat != paramsDefault.sparams.penalty_repeat) result += std::format("-> {} = {:.3f}", name_penalty_repeat, params.sparams.penalty_repeat);
+    if (params.sparams.penalty_threshold != paramsDefault.sparams.penalty_threshold) result += std::format("-> {} = {:.3f}", name_penalty_threshold, params.sparams.penalty_threshold);
+    if (params.sparams.penalty_freq != paramsDefault.sparams.penalty_freq) result += std::format("-> {} = {:.3f}", name_penalty_freq, params.sparams.penalty_freq);
+    if (params.sparams.penalty_present != paramsDefault.sparams.penalty_present) result += std::format("-> {} = {:.3f}", name_penalty_present, params.sparams.penalty_present);
+    // DRY
+    if (params.sparams.dry_multiplier != paramsDefault.sparams.dry_multiplier) result += std::format("-> {} = {:.3f}", name_dry_multiplier, params.sparams.dry_multiplier);
+    if (params.sparams.dry_base != paramsDefault.sparams.dry_base) result += std::format("-> {} = {:.3f}", name_dry_base, params.sparams.dry_base);
+    if (params.sparams.dry_allowed_length != paramsDefault.sparams.dry_allowed_length) result += std::format("-> {} = {}", name_dry_allowed_length, params.sparams.dry_allowed_length);
+    if (params.sparams.dry_penalty_last_n != paramsDefault.sparams.dry_penalty_last_n) result += std::format("-> {} = {}", name_dry_penalty_last_n, params.sparams.dry_penalty_last_n);
     // mirostat is special
     if (params.sparams.mirostat != paramsDefault.sparams.mirostat) {
         if (params.sparams.dynatemp_range > 0) {

@@ -537,7 +545,7 @@ class chat
             result += std::format("/{:.2f}*{:.2f}", params.sparams.smoothing_factor, params.sparams.smoothing_curve);
         }
         result += "-> " + name_mirostat + " = " + std::to_string(params.sparams.mirostat);
-        result += std::format("; {} = {:.2f}", name_mirostat_tau, params.sparams.mirostat_tau);
+        result += std::format("; {} = {:.2f}", name_mirostat_tau, params.sparams.mirostat_tau);
         result += std::format("; {} = {:.2f}", name_mirostat_eta, params.sparams.mirostat_eta);
     } else {
         for (auto s : params.sparams.samplers_sequence){

@@ -1517,14 +1525,14 @@ class chat
 
 
         if (input_echo) {
-            printf("-pei");
+            //printf("-pei");
             for (auto id : embd) {
                 //std::string tknStr = llama_token_to_string(ctx, id);
                 const std::string tknStr = llama_token_to_piece(ctx, id);
                 //result += (std::string) tknStr;
                 result += tknStr;
                 //if (streaming) printf("%s", tknStr);
-                std::cout<<tknStr;
+                //std::cout<<tknStr;
             }
 
         }

@@ -1638,7 +1646,7 @@ class chat
     // initial (instruct) processing
     std::string process_prompt(bool consoleOutput = true, bool verbose = false) {
 
-        printf("Starting initial prompt processing...\n");
+        if (debug) printf("Starting initial prompt processing...\n");
 
         std::string result;
         //std::cout << " * " << std::endl;

@@ -1684,7 +1692,7 @@ class chat
         if (verbose) {
             if (!streaming) std::cout << result << " ";
 
-            printf("Return generate: prompt processed\n");
+            if (debug) printf("Return generate: prompt processed\n");
         }
 
         // get_speed();
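The parameter printout switches from std::to_string concatenation to std::format (a C++20 facility), which pins floats to three decimals instead of std::to_string's six. An illustrative comparison (names are made up):

// Illustrative: std::to_string vs std::format for the parameter printout.
// Requires C++20 for <format>.
#include <cstdio>
#include <format>
#include <string>

int main() {
    const float penalty_repeat = 1.1f;
    std::string a = "->penalty_repeat = " + std::to_string(penalty_repeat);          // ...= 1.100000
    std::string b = std::format("-> {} = {:.3f}", "penalty_repeat", penalty_repeat); // ...= 1.100
    printf("%s\n%s\n", a.c_str(), b.c_str());
    return 0;
}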

include/jsonParams.h

Lines changed: 11 additions & 2 deletions

@@ -480,8 +480,8 @@ static void getParamsFromJson(nlohmann::json& config, gpt_params& params, bool h
     if (checkJString(config, "samplers_sequence")) params.sparams.samplers_sequence = config["samplers_sequence"];
     if (checkJString(config, "bos")) params.bos = config["bos"];
     if (checkJString(config, "eos")) params.eos = config["eos"];
-
     if (checkJNum(config, "seed")) params.seed = config["seed"];
+    // threading
     if (checkJNum(config, "n_threads")) params.n_threads = config["n_threads"];
     if (checkJNum(config, "n_threads_batch")) params.n_threads_batch = config["n_threads_batch"];
     if (checkJNum(config, "n_gpu_layers")) params.n_gpu_layers = config["n_gpu_layers"];

@@ -493,7 +493,7 @@ static void getParamsFromJson(nlohmann::json& config, gpt_params& params, bool h
     if (checkJNum(config, "min_keep")) params.sparams.min_keep = config["min_keep"];
     if (checkJNum(config, "n_batch")) params.n_batch = config["n_batch"];
     if (checkJNum(config, "n_ubatch")) params.n_ubatch = config["n_ubatch"];
-
+    // sampling
     load_param_num(config, "temp", params.sparams.temp, params.sparams.temp_func);
     load_param_num(config, "dynatemp_range", params.sparams.dynatemp_range, params.sparams.dynatemp_range_func);
 
@@ -507,14 +507,23 @@ static void getParamsFromJson(nlohmann::json& config, gpt_params& params, bool h
     //if (checkJNum(config, "p_step")) params.sparams.p_step = config["p_step"];
     load_param_num(config, "p_step", params.sparams.p_step, params.sparams.p_step_func);
     if (checkJNum(config, "tfs_z")) params.sparams.tfs_z = config["tfs_z"];
+    // penalties
     if (checkJNum(config, "repeat_penalty")) params.sparams.penalty_repeat = config["repeat_penalty"];
+    if (checkJNum(config, "penalty_repeat")) params.sparams.penalty_repeat = config["penalty_repeat"];
     if (checkJNum(config, "penalty_threshold")) params.sparams.penalty_threshold = config["penalty_threshold"];
     if (checkJNum(config, "frequency_penalty")) params.sparams.penalty_freq = config["frequency_penalty"];
     if (checkJNum(config, "presence_penalty")) params.sparams.penalty_present = config["presence_penalty"];
+    // DRY
+    if (checkJNum(config, "dry_multiplier")) params.sparams.dry_multiplier = config["dry_multiplier"];
+    if (checkJNum(config, "dry_base")) params.sparams.dry_base = config["dry_base"];
+    if (checkJNum(config, "dry_allowed_length")) params.sparams.dry_allowed_length = config["dry_allowed_length"];
+    if (checkJNum(config, "dry_penalty_last_n")) params.sparams.dry_penalty_last_n = config["dry_penalty_last_n"];
+    // mirostat
     if (checkJNum(config, "mirostat")) params.sparams.mirostat = config["mirostat"];
     if (checkJNum(config, "mirostat_tau")) params.sparams.mirostat_tau = config["mirostat_tau"];
     if (checkJNum(config, "mirostat_eta")) params.sparams.mirostat_eta = config["mirostat_eta"];
     //if (config["color"].is_boolean()) params.use_color = config["color"];
+    // misc
     if (config["penalize_nl"].is_boolean()) params.sparams.penalize_nl = config["penalize_nl"];
     if (config["use_mmap"].is_boolean()) params.use_mmap = config["use_mmap"];
     if (config["flash_attn"].is_boolean()) params.flash_attn = config["flash_attn"];
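For reference, a hypothetical JSON snippet exercising the new DRY keys (dry_seq_breakers is not parsed from JSON in this commit, so breakers can only be set in code):

// Hypothetical config sketch for the new DRY keys (illustrative only).
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json config = nlohmann::json::parse(R"({
        "dry_multiplier": 0.8,
        "dry_base": 1.75,
        "dry_allowed_length": 2,
        "dry_penalty_last_n": -1
    })");
    // getParamsFromJson(config, params, ...) would copy these into params.sparams
    std::cout << config.dump(2) << std::endl;
    return 0;
}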
