Skip to content

Commit 875ff55

Browse files
committed
DRY: fixes, adjustments from code review
1 parent c210cba commit 875ff55

File tree

9 files changed: +149 additions, -45 deletions

common/common.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,15 @@ enum llama_example {
 
 enum common_sampler_type {
     COMMON_SAMPLER_TYPE_NONE = 0,
-    COMMON_SAMPLER_TYPE_TOP_K = 1,
-    COMMON_SAMPLER_TYPE_TOP_P = 2,
-    COMMON_SAMPLER_TYPE_MIN_P = 3,
-    COMMON_SAMPLER_TYPE_TFS_Z = 4,
-    COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
-    COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
-    COMMON_SAMPLER_TYPE_XTC = 7,
-    COMMON_SAMPLER_TYPE_INFILL = 8,
+    COMMON_SAMPLER_TYPE_DRY = 1,
+    COMMON_SAMPLER_TYPE_TOP_K = 2,
+    COMMON_SAMPLER_TYPE_TOP_P = 3,
+    COMMON_SAMPLER_TYPE_MIN_P = 4,
+    COMMON_SAMPLER_TYPE_TFS_Z = 5,
+    COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
+    COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
+    COMMON_SAMPLER_TYPE_XTC = 8,
+    COMMON_SAMPLER_TYPE_INFILL = 9,
 };
 
 // dimensionality reduction methods, used by cvector-generator

@@ -136,6 +137,7 @@ struct common_sampler_params {
 
 
     std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_DRY,
         COMMON_SAMPLER_TYPE_TOP_K,
         COMMON_SAMPLER_TYPE_TFS_Z,
         COMMON_SAMPLER_TYPE_TYPICAL_P,

common/sampling.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
 
 #include "common.h"
 
+#include "log.h"
+
 #include <cmath>
 #include <unordered_map>
 
+extern void llama_sampler_dry_set_seq_breakers(struct llama_sampler * smpl, const std::vector<std::string>& seq_breakers);
+
 // the ring buffer works similarly to std::deque, but with a fixed capacity
 // TODO: deduplicate with llama-impl.h
 template<typename T>

@@ -98,6 +102,8 @@ struct ring_buffer {
     std::vector<T> data;
 };
 
+
+
 struct common_sampler {
     common_sampler_params params;
 

@@ -173,17 +179,19 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             params.penalize_nl,
             params.ignore_eos));
 
-    if (params.dry_multiplier != 0.0f && params.dry_base != 0.0f) {
-        auto * dry_sampler = llama_sampler_init_dry(model, context_size, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n);
-
-        llama_sampler_dry_set_seq_breakers(dry_sampler, params.dry_sequence_breakers);
-        llama_sampler_chain_add(result->chain, dry_sampler);
-    }
+    struct llama_sampler * dry_sampler = nullptr;
 
     if (params.temp > 0.0f) {
         if (params.mirostat == 0) {
             for (const auto & cnstr : params.samplers) {
                 switch (cnstr) {
+                    case COMMON_SAMPLER_TYPE_DRY:
+                        dry_sampler = llama_sampler_init_dry(model, context_size, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n);
+                        if (dry_sampler != nullptr) {
+                            llama_sampler_dry_set_seq_breakers(dry_sampler, params.dry_sequence_breakers);
+                            llama_sampler_chain_add(result->chain, dry_sampler);
+                        }
+                        break;
                     case COMMON_SAMPLER_TYPE_TOP_K:
                         llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
                         break;

@@ -236,6 +244,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
     }
 
+    // // If DRY sampler wasn't added to the chain, free it
+    // if (dry_sampler) {
+    //     llama_sampler_free(dry_sampler);
+    // }
+
     return result;
 }
 

@@ -381,6 +394,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
 
 char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
     switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY: return 'd';
         case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
         case COMMON_SAMPLER_TYPE_TFS_Z: return 'f';
         case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';

@@ -395,6 +409,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
 
 std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
     switch (cnstr) {
+        case COMMON_SAMPLER_TYPE_DRY: return "dry";
         case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
         case COMMON_SAMPLER_TYPE_TFS_Z: return "tfs_z";
         case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";

@@ -409,6 +424,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
 
 std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
     std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
+        { "dry", COMMON_SAMPLER_TYPE_DRY },
         { "top_k", COMMON_SAMPLER_TYPE_TOP_K },
        { "top_p", COMMON_SAMPLER_TYPE_TOP_P },
         { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },

@@ -457,6 +473,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
 
 std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
     std::unordered_map<char, common_sampler_type> sampler_name_map = {
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TFS_Z), COMMON_SAMPLER_TYPE_TFS_Z },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },

examples/server/public/index-new.html

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
       repeat_last_n: 0, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.0, // 1.0 = disabled
       penalize_nl: false, // true only useful for infinite completion
+      dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
+      dry_base: 1.75, // 0.0 = disabled
+      dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
+      dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
       top_k: 0, // <= 0 to use vocab size
       top_p: 1.0, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4

@@ -833,15 +837,19 @@
           <fieldset class="params">
             ${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
             ${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
-            ${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
             ${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
+            ${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
             ${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
             ${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
+            ${FloatField({ label: "DRY Penalty Multiplier", title: "Set the DRY repetition penalty multiplier. Default is 0.0, which is disabled.", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })}
+            ${FloatField({ label: "DRY Base", title: "Set the DRY repetition penalty base value. Default is 1.75", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })}
+            ${IntField({ label: "DRY Allowed Length", title: "Tokens that extend repetition beyond this receive exponentially increasing penalty. Default is 2", max: 10, min: 2, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })}
+            ${IntField({ label: "DRY Penalty Last N", title: "How many tokens to scan for repetitions. Default is -1, where 0 is disabled and -1 is context size", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })}
+            ${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })}
             ${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
-          </fieldset>
+          </fieldset>
 
           <hr style="height: 1px; background-color: #ececf1; border: none;" />

(the two `</fieldset>` lines above differ only in leading whitespace, which the page extraction stripped)

@@ -1144,6 +1152,8 @@ <h2>llama.cpp</h2>
         repeat_penalty: { snapValue: 1.0, snapRangeMultiplier: 4 },
         presence_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
         frequency_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 },
+        dry_multiplier: { snapValue: 0.0, snapRangeMultiplier: 4 },
+        dry_base: { snapValue: 1.75, snapRangeMultiplier: 4 },
       };
       // add an event listener for each slider
       Object.keys(snapSettings).forEach(sliderName => {

examples/server/public/index.html

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,10 @@
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
       penalize_nl: false,
+      dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
+      dry_base: 1.75, // 0.0 = disabled
+      dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
+      dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
       top_k: 40, // <= 0 to use vocab size
       top_p: 0.95, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled

@@ -1015,6 +1019,10 @@
             ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })}
             ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
             ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })}
+            ${FloatField({ label: "DRY Penalty Multiplier", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })}
+            ${FloatField({ label: "DRY Base", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })}
+            ${IntField({ label: "DRY Allowed Length", max: 10, min: 2, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })}
+            ${IntField({ label: "DRY Penalty Last N", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })}
             ${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })}
             ${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}
             </fieldset>

examples/server/public/style.css

100755100644
File mode changed.

examples/server/server.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -863,8 +863,8 @@ struct server_context {
             slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
             slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
             slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
-            slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
-            slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
+            slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
+            slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
             slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
             slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
             slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);

(the changed lines in this hunk differ only in alignment whitespace, which the page extraction stripped)

@@ -887,8 +887,8 @@ struct server_context {
             slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
             slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
             slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
-            //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
-            slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
+            //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
+            slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
 
             // sequence breakers for DRY
             {

(the changed lines in this hunk differ only in alignment whitespace, which the page extraction stripped)

@@ -2170,7 +2170,7 @@ struct server_context {
                     }
 
                     // Should this be (re-)moved?
-                    common_sampler_reset(slot.smpl);
+                    //common_sampler_reset(slot.smpl);
 
                     if (slot.params.cache_prompt) {
                         // reuse any previously computed tokens that are common with the new prompt

@@ -2269,7 +2269,7 @@ struct server_context {
                         // there is no common part left
                         slot.n_past = 0;
 
-                        common_sampler_reset(slot.smpl);
+                        //common_sampler_reset(slot.smpl);
                     }
 
                     SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);

@@ -2297,6 +2297,8 @@ struct server_context {
 
                     GGML_ASSERT(batch.n_tokens > 0);
 
+                    common_sampler_reset(slot.smpl);
+
                     // Process all prompt tokens through sampler system
                     for (int i = 0; i < slot.n_prompt_tokens; ++i) {
                         common_sampler_accept(slot.smpl, prompt_tokens[i], false);

include/llama.h

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,11 @@ extern "C" {
                                int32_t   dry_allowed_length,
                                int32_t   dry_penalty_last_n);
 
+    LLAMA_API void llama_sampler_dry_set_seq_breakers_c(
+            struct llama_sampler * smpl,
+            const char ** seq_breakers,
+            int num_breakers);
+
     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
                                int32_t   n_vocab,
                                int32_t   n_logit_bias,

@@ -1262,15 +1267,4 @@
 }
 #endif
 
-// Need to find a cleaner way to implement the sequence breakers as a vector of strings
-#ifdef __cplusplus
-
-#include <vector>
-#include <string>
-
-LLAMA_API void llama_sampler_dry_set_seq_breakers(struct llama_sampler * sampler, const std::vector<std::string>& seq_breakers);
-LLAMA_API void llama_sampler_dry_set_seq_breakers_as_tokens(struct llama_sampler * smpl, const std::vector<std::vector<llama_token>>& seq_breakers);
-
-#endif // __cplusplus
-
-#endif // LLAMA_H
+#endif // LLAMA_H

0 commit comments

Comments
 (0)