Skip to content

Commit 877566d

Browse files
authored
llama: introduce support for model-embedded sampling parameters (ggml-org#17120)
1 parent 3d07caa commit 877566d

File tree

10 files changed

+293
-13
lines changed

10 files changed

+293
-13
lines changed

common/arg.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
12321232
[](common_params & params, const std::string & value) {
12331233
const auto sampler_names = string_split<std::string>(value, ';');
12341234
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
1235+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
12351236
}
12361237
).set_sparam());
12371238
add_opt(common_arg(
@@ -1261,27 +1262,31 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
12611262
[](common_params & params, const std::string & value) {
12621263
params.sampling.temp = std::stof(value);
12631264
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
1265+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP;
12641266
}
12651267
).set_sparam());
12661268
add_opt(common_arg(
12671269
{"--top-k"}, "N",
12681270
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
12691271
[](common_params & params, int value) {
12701272
params.sampling.top_k = value;
1273+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
12711274
}
12721275
).set_sparam());
12731276
add_opt(common_arg(
12741277
{"--top-p"}, "N",
12751278
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
12761279
[](common_params & params, const std::string & value) {
12771280
params.sampling.top_p = std::stof(value);
1281+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
12781282
}
12791283
).set_sparam());
12801284
add_opt(common_arg(
12811285
{"--min-p"}, "N",
12821286
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
12831287
[](common_params & params, const std::string & value) {
12841288
params.sampling.min_p = std::stof(value);
1289+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
12851290
}
12861291
).set_sparam());
12871292
add_opt(common_arg(
@@ -1296,13 +1301,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
12961301
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
12971302
[](common_params & params, const std::string & value) {
12981303
params.sampling.xtc_probability = std::stof(value);
1304+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
12991305
}
13001306
).set_sparam());
13011307
add_opt(common_arg(
13021308
{"--xtc-threshold"}, "N",
13031309
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
13041310
[](common_params & params, const std::string & value) {
13051311
params.sampling.xtc_threshold = std::stof(value);
1312+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
13061313
}
13071314
).set_sparam());
13081315
add_opt(common_arg(
@@ -1321,13 +1328,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
13211328
}
13221329
params.sampling.penalty_last_n = value;
13231330
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
1331+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N;
13241332
}
13251333
).set_sparam());
13261334
add_opt(common_arg(
13271335
{"--repeat-penalty"}, "N",
13281336
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
13291337
[](common_params & params, const std::string & value) {
13301338
params.sampling.penalty_repeat = std::stof(value);
1339+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
13311340
}
13321341
).set_sparam());
13331342
add_opt(common_arg(
@@ -1425,20 +1434,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
14251434
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
14261435
[](common_params & params, int value) {
14271436
params.sampling.mirostat = value;
1437+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT;
14281438
}
14291439
).set_sparam());
14301440
add_opt(common_arg(
14311441
{"--mirostat-lr"}, "N",
14321442
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
14331443
[](common_params & params, const std::string & value) {
14341444
params.sampling.mirostat_eta = std::stof(value);
1445+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
14351446
}
14361447
).set_sparam());
14371448
add_opt(common_arg(
14381449
{"--mirostat-ent"}, "N",
14391450
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
14401451
[](common_params & params, const std::string & value) {
14411452
params.sampling.mirostat_tau = std::stof(value);
1453+
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
14421454
}
14431455
).set_sparam());
14441456
add_opt(common_arg(

common/common.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "common.h"
99
#include "log.h"
1010
#include "llama.h"
11+
#include "sampling.h"
1112

1213
#include <algorithm>
1314
#include <cinttypes>
@@ -949,6 +950,58 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
949950
// Model utils
950951
//
951952

953+
static inline void common_init_sampler_from_model(
954+
const llama_model * model,
955+
common_params_sampling & sparams) {
956+
957+
const uint64_t config = sparams.user_sampling_config;
958+
959+
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
960+
if (config & user_config) return;
961+
962+
char buf[64] = {0};
963+
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
964+
char * end = nullptr;
965+
int32_t v = strtol(buf, &end, 10);
966+
if (end && end != buf) dst = v;
967+
}
968+
};
969+
970+
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
971+
if (config & user_config) return;
972+
973+
char buf[128] = {0};
974+
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
975+
char * end = nullptr;
976+
float v = strtof(buf, &end);
977+
if (end && end != buf) dst = v;
978+
}
979+
};
980+
981+
// Sampling sequence
982+
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
983+
char buf[512] = {0};
984+
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
985+
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
986+
if (!sampler_names.empty()) {
987+
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
988+
}
989+
}
990+
}
991+
992+
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
993+
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
994+
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
995+
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
996+
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
997+
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
998+
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
999+
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
1000+
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
1001+
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
1002+
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
1003+
}
1004+
9521005
struct common_init_result common_init_from_params(common_params & params) {
9531006
common_init_result iparams;
9541007
auto mparams = common_model_params_to_llama(params);
@@ -960,6 +1013,8 @@ struct common_init_result common_init_from_params(common_params & params) {
9601013
return iparams;
9611014
}
9621015

1016+
common_init_sampler_from_model(model, params.sampling);
1017+
9631018
const llama_vocab * vocab = llama_model_get_vocab(model);
9641019

9651020
auto cparams = common_context_params_to_llama(params);

common/common.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,22 @@ struct common_grammar_trigger {
140140
llama_token token = LLAMA_TOKEN_NULL;
141141
};
142142

143+
// Bit masks, one per sampling parameter. They are OR-ed into
// common_params_sampling::user_sampling_config to record which parameters were
// set explicitly and must therefore not be replaced by model-embedded values.
enum common_params_sampling_config : uint64_t {
    COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS        = 0x001, // bit  0
    COMMON_PARAMS_SAMPLING_CONFIG_TOP_K           = 0x002, // bit  1
    COMMON_PARAMS_SAMPLING_CONFIG_TOP_P           = 0x004, // bit  2
    COMMON_PARAMS_SAMPLING_CONFIG_MIN_P           = 0x008, // bit  3
    COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 0x010, // bit  4
    COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD   = 0x020, // bit  5
    COMMON_PARAMS_SAMPLING_CONFIG_TEMP            = 0x040, // bit  6
    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N  = 0x080, // bit  7
    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT  = 0x100, // bit  8
    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT        = 0x200, // bit  9
    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU    = 0x400, // bit 10
    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA    = 0x800, // bit 11
};
157+
158+
143159
// sampling parameters
144160
struct common_params_sampling {
145161
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -172,6 +188,8 @@ struct common_params_sampling {
172188
bool no_perf = false; // disable performance metrics
173189
bool timing_per_token = false;
174190

191+
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
192+
175193
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
176194

177195

gguf-py/gguf/constants.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,20 @@ class General:
2525
ALIGNMENT = "general.alignment"
2626
FILE_TYPE = "general.file_type"
2727

28+
# Recommended Sampler Parameters
29+
SAMPLING_SEQUENCE = "general.sampling.sequence"
30+
SAMPLING_TOP_K = "general.sampling.top_k"
31+
SAMPLING_TOP_P = "general.sampling.top_p"
32+
SAMPLING_MIN_P = "general.sampling.min_p"
33+
SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability"
34+
SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold"
35+
SAMPLING_TEMP = "general.sampling.temp"
36+
SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n"
37+
SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat"
38+
SAMPLING_MIROSTAT = "general.sampling.mirostat"
39+
SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau"
40+
SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta"
41+
2842
# Authorship Metadata
2943
NAME = "general.name"
3044
AUTHOR = "general.author"

gguf-py/gguf/gguf_writer.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,42 @@ def add_custom_alignment(self, alignment: int) -> None:
496496
def add_file_type(self, ftype: int) -> None:
    """Store ``ftype`` under the ``general.file_type`` key via ``add_uint32``."""
    key = Keys.General.FILE_TYPE
    self.add_uint32(key, ftype)
498498

499+
def add_sampling_sequence(self, sequence: str) -> None:
    """Store the recommended sampler sequence under ``general.sampling.sequence``."""
    key = Keys.General.SAMPLING_SEQUENCE
    self.add_string(key, sequence)
501+
502+
def add_sampling_top_k(self, top_k: int) -> None:
    """Store the recommended top-k value under ``general.sampling.top_k``."""
    key = Keys.General.SAMPLING_TOP_K
    self.add_int32(key, top_k)
504+
505+
def add_sampling_top_p(self, top_p: float) -> None:
    """Store the recommended top-p value under ``general.sampling.top_p``."""
    key = Keys.General.SAMPLING_TOP_P
    self.add_float32(key, top_p)
507+
508+
def add_sampling_min_p(self, min_p: float) -> None:
    """Store the recommended min-p value under ``general.sampling.min_p``."""
    key = Keys.General.SAMPLING_MIN_P
    self.add_float32(key, min_p)
510+
511+
def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
    """Store the recommended XTC probability under ``general.sampling.xtc_probability``."""
    key = Keys.General.SAMPLING_XTC_PROBABILITY
    self.add_float32(key, xtc_probability)
513+
514+
def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
    """Store the recommended XTC threshold under ``general.sampling.xtc_threshold``."""
    key = Keys.General.SAMPLING_XTC_THRESHOLD
    self.add_float32(key, xtc_threshold)
516+
517+
def add_sampling_temp(self, temp: float) -> None:
    """Store the recommended sampling temperature under ``general.sampling.temp``."""
    key = Keys.General.SAMPLING_TEMP
    self.add_float32(key, temp)
519+
520+
def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
    """Store the recommended penalty window under ``general.sampling.penalty_last_n``."""
    key = Keys.General.SAMPLING_PENALTY_LAST_N
    self.add_int32(key, penalty_last_n)
522+
523+
def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
    """Store the recommended repeat penalty under ``general.sampling.penalty_repeat``."""
    key = Keys.General.SAMPLING_PENALTY_REPEAT
    self.add_float32(key, penalty_repeat)
525+
526+
def add_sampling_mirostat(self, mirostat: int) -> None:
    """Store the recommended Mirostat mode under ``general.sampling.mirostat``."""
    key = Keys.General.SAMPLING_MIROSTAT
    self.add_int32(key, mirostat)
528+
529+
def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
    """Store the recommended Mirostat tau under ``general.sampling.mirostat_tau``."""
    key = Keys.General.SAMPLING_MIROSTAT_TAU
    self.add_float32(key, mirostat_tau)
531+
532+
def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
    """Store the recommended Mirostat eta under ``general.sampling.mirostat_eta``."""
    key = Keys.General.SAMPLING_MIROSTAT_ETA
    self.add_float32(key, mirostat_eta)
534+
499535
def add_name(self, name: str) -> None:
    """Store ``name`` under the ``general.name`` key via ``add_string``."""
    key = Keys.General.NAME
    self.add_string(key, name)
501537

0 commit comments

Comments
 (0)