Skip to content

Commit 5ccb05b

Browse files
committed
Define COMMON_SAMPLER_TYPE_TOP_N_SIGMA
1 parent ca992ad commit 5ccb05b

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

common/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ enum common_sampler_type {
9696
COMMON_SAMPLER_TYPE_XTC = 8,
9797
COMMON_SAMPLER_TYPE_INFILL = 9,
9898
COMMON_SAMPLER_TYPE_PENALTIES = 10,
99+
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
99100
};
100101

101102
// dimensionality reduction methods, used by cvector-generator
@@ -153,6 +154,7 @@ struct common_params_sampling {
153154
std::vector<enum common_sampler_type> samplers = {
154155
COMMON_SAMPLER_TYPE_PENALTIES,
155156
COMMON_SAMPLER_TYPE_DRY,
157+
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
156158
COMMON_SAMPLER_TYPE_TOP_K,
157159
COMMON_SAMPLER_TYPE_TYPICAL_P,
158160
COMMON_SAMPLER_TYPE_TOP_P,

common/sampling.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
207207
case COMMON_SAMPLER_TYPE_TOP_P:
208208
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
209209
break;
210+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
211+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
212+
break;
210213
case COMMON_SAMPLER_TYPE_MIN_P:
211214
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
212215
break;
@@ -219,9 +222,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
219222
case COMMON_SAMPLER_TYPE_TEMPERATURE:
220223
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
221224
break;
222-
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
223-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
224-
break;
225225
case COMMON_SAMPLER_TYPE_INFILL:
226226
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
227227
break;
@@ -431,6 +431,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
431431
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
432432
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
433433
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
434+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
434435
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
435436
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
436437
case COMMON_SAMPLER_TYPE_XTC: return 'x';
@@ -446,6 +447,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
446447
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
447448
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
448449
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
450+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
449451
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
450452
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
451453
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
@@ -460,6 +462,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
460462
{ "dry", COMMON_SAMPLER_TYPE_DRY },
461463
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
462464
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
465+
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
463466
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
464467
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
465468
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
@@ -473,6 +476,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
473476
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
474477
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
475478
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
479+
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
476480
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
477481
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
478482
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -508,6 +512,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
508512
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
509513
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
510514
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
515+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
511516
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
512517
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
513518
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },

0 commit comments

Comments
 (0)