Skip to content

Commit a46f8ac

Browse files
committed
note: also has support for completion tokens count
2 parents aa26a58 + 8f275a7 commit a46f8ac

31 files changed

+152759
-151482
lines changed

common/arg.cpp

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,13 @@ static void common_params_handle_model_default(common_params & params) {
129129
}
130130
params.hf_file = params.model;
131131
} else if (params.model.empty()) {
132-
params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
132+
params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back());
133133
}
134134
} else if (!params.model_url.empty()) {
135135
if (params.model.empty()) {
136-
auto f = string_split(params.model_url, '#').front();
137-
f = string_split(f, '?').front();
138-
params.model = fs_get_cache_file(string_split(f, '/').back());
136+
auto f = string_split<std::string>(params.model_url, '#').front();
137+
f = string_split<std::string>(f, '?').front();
138+
params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
139139
}
140140
} else if (params.model.empty()) {
141141
params.model = DEFAULT_MODEL_PATH;
@@ -252,6 +252,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
252252
for (auto & antiprompt : params.antiprompt) {
253253
string_process_escapes(antiprompt);
254254
}
255+
for (auto & seq_breaker : params.sparams.dry_sequence_breakers) {
256+
string_process_escapes(seq_breaker);
257+
}
255258
}
256259

257260
if (!params.kv_overrides.empty()) {
@@ -880,7 +883,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
880883
{"--samplers"}, "SAMPLERS",
881884
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
882885
[](common_params & params, const std::string & value) {
883-
const auto sampler_names = string_split(value, ';');
886+
const auto sampler_names = string_split<std::string>(value, ';');
884887
params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
885888
}
886889
).set_sparam());
@@ -941,13 +944,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
941944
params.sparams.min_p = std::stof(value);
942945
}
943946
).set_sparam());
944-
add_opt(common_arg(
945-
{"--tfs"}, "N",
946-
string_format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z),
947-
[](common_params & params, const std::string & value) {
948-
params.sparams.tfs_z = std::stof(value);
949-
}
950-
).set_sparam());
951947
add_opt(common_arg(
952948
{"--xtc-probability"}, "N",
953949
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
@@ -998,6 +994,64 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
998994
params.sparams.penalty_freq = std::stof(value);
999995
}
1000996
).set_sparam());
997+
add_opt(common_arg(
998+
{"--dry-multiplier"}, "N",
999+
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
1000+
[](common_params & params, const std::string & value) {
1001+
params.sparams.dry_multiplier = std::stof(value);
1002+
}
1003+
).set_sparam());
1004+
add_opt(common_arg(
1005+
{"--dry-base"}, "N",
1006+
string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
1007+
[](common_params & params, const std::string & value) {
1008+
float potential_base = std::stof(value);
1009+
if (potential_base >= 1.0f)
1010+
{
1011+
params.sparams.dry_base = potential_base;
1012+
}
1013+
}
1014+
).set_sparam());
1015+
add_opt(common_arg(
1016+
{"--dry-allowed-length"}, "N",
1017+
string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
1018+
[](common_params & params, int value) {
1019+
params.sparams.dry_allowed_length = value;
1020+
}
1021+
).set_sparam());
1022+
add_opt(common_arg(
1023+
{"--dry-penalty-last-n"}, "N",
1024+
string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
1025+
[](common_params & params, int value) {
1026+
params.sparams.dry_penalty_last_n = value;
1027+
}
1028+
).set_sparam());
1029+
add_opt(common_arg(
1030+
{"--dry-sequence-breaker"}, "STRING",
1031+
string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n",
1032+
params.sparams.dry_sequence_breakers.empty() ? "none" :
1033+
std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()),
1034+
params.sparams.dry_sequence_breakers.end(),
1035+
std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'",
1036+
[](const std::string& a, const std::string& b) {
1037+
std::string formatted_b = (b == "\n") ? "\\n" : b;
1038+
return a + ", '" + formatted_b + "'";
1039+
}).c_str()),
1040+
[](common_params & params, const std::string & value) {
1041+
static bool defaults_cleared = false;
1042+
1043+
if (!defaults_cleared) {
1044+
params.sparams.dry_sequence_breakers.clear();
1045+
defaults_cleared = true;
1046+
}
1047+
1048+
if (value == "none") {
1049+
params.sparams.dry_sequence_breakers.clear();
1050+
} else {
1051+
params.sparams.dry_sequence_breakers.emplace_back(value);
1052+
}
1053+
}
1054+
).set_sparam());
10011055
add_opt(common_arg(
10021056
{"--dynatemp-range"}, "N",
10031057
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
@@ -1014,7 +1068,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
10141068
).set_sparam());
10151069
add_opt(common_arg(
10161070
{"--mirostat"}, "N",
1017-
string_format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
1071+
string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n"
10181072
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat),
10191073
[](common_params & params, int value) {
10201074
params.sparams.mirostat = value;

common/common.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -418,19 +418,6 @@ std::string string_format(const char * fmt, ...) {
418418
return std::string(buf.data(), size);
419419
}
420420

421-
std::vector<std::string> string_split(std::string input, char separator) {
422-
std::vector<std::string> parts;
423-
size_t separator_pos = input.find(separator);
424-
while (separator_pos != std::string::npos) {
425-
std::string part = input.substr(0, separator_pos);
426-
parts.emplace_back(part);
427-
input = input.substr(separator_pos + 1);
428-
separator_pos = input.find(separator);
429-
}
430-
parts.emplace_back(input);
431-
return parts;
432-
}
433-
434421
std::string string_strip(const std::string & str) {
435422
size_t start = 0;
436423
size_t end = str.size();
@@ -2021,6 +2008,10 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
20212008
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
20222009
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
20232010
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
2011+
fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
2012+
fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
2013+
fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
2014+
fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
20242015
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
20252016
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
20262017
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
@@ -2101,7 +2092,6 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
21012092
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
21022093
yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);
21032094

2104-
fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
21052095
fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
21062096
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
21072097
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);

common/common.h

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,15 @@ enum llama_example {
8080

8181
enum common_sampler_type {
8282
COMMON_SAMPLER_TYPE_NONE = 0,
83-
COMMON_SAMPLER_TYPE_TOP_K = 1,
84-
COMMON_SAMPLER_TYPE_TOP_P = 2,
85-
COMMON_SAMPLER_TYPE_MIN_P = 3,
86-
COMMON_SAMPLER_TYPE_TFS_Z = 4,
87-
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
88-
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
89-
COMMON_SAMPLER_TYPE_XTC = 7,
90-
COMMON_SAMPLER_TYPE_INFILL = 8,
83+
COMMON_SAMPLER_TYPE_DRY = 1,
84+
COMMON_SAMPLER_TYPE_TOP_K = 2,
85+
COMMON_SAMPLER_TYPE_TOP_P = 3,
86+
COMMON_SAMPLER_TYPE_MIN_P = 4,
87+
//COMMON_SAMPLER_TYPE_TFS_Z = 5,
88+
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
89+
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
90+
COMMON_SAMPLER_TYPE_XTC = 8,
91+
COMMON_SAMPLER_TYPE_INFILL = 9,
9192
};
9293

9394
// dimensionality reduction methods, used by cvector-generator
@@ -100,34 +101,39 @@ enum dimre_method {
100101
struct common_sampler_params {
101102
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
102103

103-
int32_t n_prev = 64; // number of previous tokens to remember
104-
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
105-
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
106-
int32_t top_k = 40; // <= 0 to use vocab size
107-
float top_p = 0.95f; // 1.0 = disabled
108-
float min_p = 0.05f; // 0.0 = disabled
109-
float xtc_probability = 0.00f; // 0.0 = disabled
110-
float xtc_threshold = 0.10f; // > 0.5 disables XTC
111-
float tfs_z = 1.00f; // 1.0 = disabled
112-
float typ_p = 1.00f; // typical_p, 1.0 = disabled
113-
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
114-
float dynatemp_range = 0.00f; // 0.0 = disabled
115-
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
116-
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
117-
float penalty_repeat = 1.00f; // 1.0 = disabled
118-
float penalty_freq = 0.00f; // 0.0 = disabled
119-
float penalty_present = 0.00f; // 0.0 = disabled
120-
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
121-
float mirostat_tau = 5.00f; // target entropy
122-
float mirostat_eta = 0.10f; // learning rate
123-
bool penalize_nl = false; // consider newlines as a repeatable token
124-
bool ignore_eos = false;
125-
bool no_perf = false; // disable performance metrics
104+
int32_t n_prev = 64; // number of previous tokens to remember
105+
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
106+
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
107+
int32_t top_k = 40; // <= 0 to use vocab size
108+
float top_p = 0.95f; // 1.0 = disabled
109+
float min_p = 0.05f; // 0.0 = disabled
110+
float xtc_probability = 0.00f; // 0.0 = disabled
111+
float xtc_threshold = 0.10f; // > 0.5 disables XTC
112+
float typ_p = 1.00f; // typical_p, 1.0 = disabled
113+
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
114+
float dynatemp_range = 0.00f; // 0.0 = disabled
115+
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
116+
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
117+
float penalty_repeat = 1.00f; // 1.0 = disabled
118+
float penalty_freq = 0.00f; // 0.0 = disabled
119+
float penalty_present = 0.00f; // 0.0 = disabled
120+
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
121+
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
122+
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
123+
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
124+
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
125+
float mirostat_tau = 5.00f; // target entropy
126+
float mirostat_eta = 0.10f; // learning rate
127+
bool penalize_nl = false; // consider newlines as a repeatable token
128+
bool ignore_eos = false;
129+
bool no_perf = false; // disable performance metrics
130+
131+
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
126132

127133

128134
std::vector<enum common_sampler_type> samplers = {
135+
COMMON_SAMPLER_TYPE_DRY,
129136
COMMON_SAMPLER_TYPE_TOP_K,
130-
COMMON_SAMPLER_TYPE_TFS_Z,
131137
COMMON_SAMPLER_TYPE_TYPICAL_P,
132138
COMMON_SAMPLER_TYPE_TOP_P,
133139
COMMON_SAMPLER_TYPE_MIN_P,
@@ -376,15 +382,14 @@ bool set_process_priority(enum ggml_sched_priority prio);
376382
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
377383
std::string string_format(const char * fmt, ...);
378384

379-
std::vector<std::string> string_split(std::string input, char separator);
380-
381385
std::string string_strip(const std::string & str);
382386
std::string string_get_sortable_timestamp();
383387

384388
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
385389

386390
template<class T>
387391
static std::vector<T> string_split(const std::string & str, char delim) {
392+
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
388393
std::vector<T> values;
389394
std::istringstream str_stream(str);
390395
std::string token;
@@ -397,6 +402,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
397402
return values;
398403
}
399404

405+
template<>
406+
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
407+
{
408+
std::vector<std::string> parts;
409+
size_t begin_pos = 0;
410+
size_t separator_pos = input.find(separator);
411+
while (separator_pos != std::string::npos) {
412+
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
413+
parts.emplace_back(part);
414+
begin_pos = separator_pos + 1;
415+
separator_pos = input.find(separator, begin_pos);
416+
}
417+
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
418+
return parts;
419+
}
420+
400421
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
401422
void string_process_escapes(std::string & input);
402423

0 commit comments

Comments
 (0)