Skip to content

Commit 2331c79

Browse files
committed
Implementation of DRY Sampling (post-sampling-refactor)
1 parent c919d5d commit 2331c79

File tree

9 files changed

+738
-74
lines changed

9 files changed

+738
-74
lines changed

common/arg.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,34 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
10091009
params.sparams.penalty_freq = std::stof(value);
10101010
}
10111011
).set_sparam());
1012+
add_opt(llama_arg(
1013+
{"--dry-multiplier"}, "N",
1014+
format("Set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
1015+
[](gpt_params & params, const std::string & value) {
1016+
params.sparams.dry_multiplier = std::stof(value);
1017+
}
1018+
).set_sparam());
1019+
add_opt(llama_arg(
1020+
{"--dry-base"}, "N",
1021+
format("Set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
1022+
[](gpt_params & params, const std::string & value) {
1023+
params.sparams.dry_base = std::stof(value);
1024+
}
1025+
).set_sparam());
1026+
add_opt(llama_arg(
1027+
{"--dry-allowed-length"}, "N",
1028+
format("Set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
1029+
[](gpt_params & params, int value) {
1030+
params.sparams.dry_allowed_length = value;
1031+
}
1032+
).set_sparam());
1033+
add_opt(llama_arg(
1034+
{"--dry-penalty-last-n"}, "N",
1035+
format("Set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
1036+
[](gpt_params & params, int value) {
1037+
params.sparams.dry_penalty_last_n = value;
1038+
}
1039+
).set_sparam());
10121040
add_opt(llama_arg(
10131041
{"--dynatemp-range"}, "N",
10141042
format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),

common/common.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1975,6 +1975,10 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
19751975
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
19761976
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
19771977
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
1978+
fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
1979+
fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
1980+
fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
1981+
	fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
19781982
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
19791983
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
19801984
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);

common/common.h

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -102,27 +102,33 @@ enum dimre_method {
102102
struct gpt_sampler_params {
103103
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
104104

105-
int32_t n_prev = 64; // number of previous tokens to remember
106-
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
107-
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
108-
int32_t top_k = 40; // <= 0 to use vocab size
109-
float top_p = 0.95f; // 1.0 = disabled
110-
float min_p = 0.05f; // 0.0 = disabled
111-
float tfs_z = 1.00f; // 1.0 = disabled
112-
float typ_p = 1.00f; // typical_p, 1.0 = disabled
113-
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
114-
float dynatemp_range = 0.00f; // 0.0 = disabled
115-
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
116-
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
117-
float penalty_repeat = 1.00f; // 1.0 = disabled
118-
float penalty_freq = 0.00f; // 0.0 = disabled
119-
float penalty_present = 0.00f; // 0.0 = disabled
120-
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
121-
float mirostat_tau = 5.00f; // target entropy
122-
float mirostat_eta = 0.10f; // learning rate
123-
bool penalize_nl = false; // consider newlines as a repeatable token
124-
bool ignore_eos = false;
125-
bool no_perf = false; // disable performance metrics
105+
int32_t n_prev = 64; // number of previous tokens to remember
106+
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
107+
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
108+
int32_t top_k = 40; // <= 0 to use vocab size
109+
float top_p = 0.95f; // 1.0 = disabled
110+
float min_p = 0.05f; // 0.0 = disabled
111+
float tfs_z = 1.00f; // 1.0 = disabled
112+
float typ_p = 1.00f; // typical_p, 1.0 = disabled
113+
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
114+
float dynatemp_range = 0.00f; // 0.0 = disabled
115+
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
116+
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
117+
float penalty_repeat = 1.00f; // 1.0 = disabled
118+
float penalty_freq = 0.00f; // 0.0 = disabled
119+
float penalty_present = 0.00f; // 0.0 = disabled
120+
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
121+
float dry_base = 1.75f;
122+
int32_t dry_allowed_length = 2;
123+
int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
124+
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
125+
float mirostat_tau = 5.00f; // target entropy
126+
float mirostat_eta = 0.10f; // learning rate
127+
bool penalize_nl = false; // consider newlines as a repeatable token
128+
bool ignore_eos = false;
129+
bool no_perf = false; // disable performance metrics
130+
131+
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
126132

127133
std::vector<enum gpt_sampler_type> samplers = {
128134
GPT_SAMPLER_TYPE_TOP_K,

common/sampling.cpp

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <cmath>
66
#include <unordered_map>
7+
#include <cstring>
78

89
// the ring buffer works similarly to std::deque, but with a fixed capacity
910
// TODO: deduplicate with llama-impl.h
@@ -110,6 +111,9 @@ struct gpt_sampler {
110111

111112
llama_token_data_array cur_p;
112113

114+
int32_t n_ctx;
115+
bool context_size_set;
116+
113117
void set_logits(struct llama_context * ctx, int idx) {
114118
const auto * logits = llama_get_logits_ith(ctx, idx);
115119

@@ -126,15 +130,17 @@ struct gpt_sampler {
126130
};
127131

128132
std::string gpt_sampler_params::print() const {
129-
char result[1024];
133+
char result[1536];
130134

131135
snprintf(result, sizeof(result),
132-
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
133-
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
134-
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
135-
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
136-
top_k, tfs_z, top_p, min_p, typ_p, temp,
137-
mirostat, mirostat_eta, mirostat_tau);
136+
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
137+
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
138+
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
139+
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
140+
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
141+
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
142+
top_k, tfs_z, top_p, min_p, typ_p, temp,
143+
mirostat, mirostat_eta, mirostat_tau);
138144

139145
return std::string(result);
140146
}
@@ -151,6 +157,8 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
151157
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
152158
/* .cur = */ {},
153159
/* .cur_p = */ {},
160+
/* .n_ctx = */ 0,
161+
/* .context_size_set = */ false,
154162
};
155163

156164
llama_sampler_chain_add(result->chain,
@@ -171,6 +179,13 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
171179
params.penalize_nl,
172180
params.ignore_eos));
173181

182+
if (params.dry_multiplier != 0.0f && params.dry_base != 0.0f) {
183+
auto * dry_sampler = llama_sampler_init_dry(model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n);
184+
185+
llama_sampler_dry_set_seq_breakers(dry_sampler, params.dry_sequence_breakers);
186+
llama_sampler_chain_add(result->chain, dry_sampler);
187+
}
188+
174189
if (params.temp > 0.0f) {
175190
if (params.mirostat == 0) {
176191
for (const auto & cnstr : params.samplers) {
@@ -273,6 +288,21 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
273288
}
274289

275290
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
291+
// Check and set the context size if it hasn't been set yet
292+
if (!gsmpl->context_size_set) {
293+
gsmpl->n_ctx = llama_n_ctx(ctx);
294+
gsmpl->context_size_set = true;
295+
296+
// Update the DRY sampler's context size if it is active
297+
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
298+
auto * sampler = llama_sampler_chain_get(gsmpl->chain, i);
299+
if (strcmp(llama_sampler_name(sampler), "dry") == 0) {
300+
llama_sampler_dry_set_context_size(sampler, gsmpl->n_ctx);
301+
break;
302+
}
303+
}
304+
}
305+
276306
gsmpl->set_logits(ctx, idx);
277307

278308
auto & grmr = gsmpl->grmr;

examples/main/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,27 @@ Use the `--no-penalize-nl` option to disable newline penalization when applying
187187

188188
Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
189189

190+
### DRY Repetition Penalty
191+
192+
DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text, even across long contexts, by penalizing tokens that would extend sequences already seen in the recent context.
193+
194+
- `--dry-multiplier N`: Set the DRY sampling multiplier (default: 0.0, 0.0 = disabled).
195+
- `--dry-base N`: Set the DRY sampling base value (default: 1.75).
196+
- `--dry-allowed-length N`: Set the allowed length for DRY sampling (default: 2).
197+
- `--dry-penalty-last-n N`: Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size).
198+
199+
The `dry-multiplier` option controls the strength of the DRY sampling effect. A value of 0.0 disables DRY sampling, while higher values increase its influence. A typical recommended value is 0.8.
200+
201+
The `dry-base` option sets the base value for the exponential penalty calculation in DRY sampling. Higher values lead to more aggressive penalization of repetitions.
202+
203+
The `dry-allowed-length` option determines the minimum length of repeated sequences that will be penalized. Repetitions shorter than or equal to this length are not penalized, allowing for natural repetitions of short phrases or common words.
204+
205+
The `dry-penalty-last-n` option controls how many recent tokens to consider when applying the DRY penalty. A value of -1 considers the entire context, while 0 disables the DRY penalty. Use a positive value to limit the consideration to a specific number of recent tokens.
206+
207+
DRY sampling works alongside traditional repetition penalties to provide more nuanced control over text generation, particularly for reducing long-range repetitions and maintaining global coherence.
208+
209+
Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1`
210+
190211
### Top-K Sampling
191212

192213
- `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40).

examples/server/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ The project is under active development, and we are [looking for feedback and co
114114
| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) |
115115
| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) |
116116
| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) |
117+
| `--dry-multiplier N` | DRY sampling multiplier (default: 0.0, 0.0 = disabled) |
118+
| `--dry-base N` | DRY sampling base value (default: 1.75) |
119+
| `--dry-allowed-length N` | allowed length for DRY sampling (default: 2) |
120+
| `--dry-penalty-last-n N` | DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) |
117121
| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) |
118122
| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) |
119123
| `--mirostat N` | use Mirostat sampling.<br/>Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.<br/>(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) |
@@ -354,6 +358,16 @@ node index.js
354358

355359
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
356360

361+
`dry_multiplier`: Set the DRY (Don't Repeat Yourself) sampling multiplier. Default: `0.0`, which is disabled.
362+
363+
`dry_base`: Set the DRY sampling base value. Default: `1.75`
364+
365+
`dry_allowed_length`: Set the allowed length for DRY sampling. Default: `2`
366+
367+
`dry_penalty_last_n`: Set DRY penalty for the last n tokens. Default: `-1`, where `0` is disabled and `-1` is context size.
368+
369+
`dry_sequence_breakers`: Specify an array of sequence breakers for DRY sampling. Can be provided as a JSON array of strings or as a JSON-encoded string representing an array of strings. Default: `["\n", ":", "\"", "*"]`
370+
357371
`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
358372

359373
`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`

0 commit comments

Comments
 (0)