@@ -229,51 +229,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
229229 params.logit_bias .data ()));
230230
231231 if (params.mirostat == 0 ) {
232- if (params.top_n_sigma >= 0 ) {
233- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
234- llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
235- llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ));
236- } else {
237- for (const auto & cnstr : params.samplers ) {
238- switch (cnstr) {
239- case COMMON_SAMPLER_TYPE_DRY:
240- {
241- std::vector<const char *> c_breakers;
242- c_breakers.reserve (params.dry_sequence_breakers .size ());
243- for (const auto & str : params.dry_sequence_breakers ) {
244- c_breakers.push_back (str.c_str ());
245- }
246-
247- llama_sampler_chain_add (result->chain , llama_sampler_init_dry (vocab, llama_model_n_ctx_train (model), params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
232+ for (const auto & cnstr : params.samplers ) {
233+ switch (cnstr) {
234+ case COMMON_SAMPLER_TYPE_DRY:
235+ {
236+ std::vector<const char *> c_breakers;
237+ c_breakers.reserve (params.dry_sequence_breakers .size ());
238+ for (const auto & str : params.dry_sequence_breakers ) {
239+ c_breakers.push_back (str.c_str ());
248240 }
249- break ;
250- case COMMON_SAMPLER_TYPE_TOP_K:
251- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
252- break ;
253- case COMMON_SAMPLER_TYPE_TOP_P:
254- llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
255- break ;
256- case COMMON_SAMPLER_TYPE_MIN_P:
257- llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
258- break ;
259- case COMMON_SAMPLER_TYPE_XTC:
260- llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
261- break ;
262- case COMMON_SAMPLER_TYPE_TYPICAL_P:
263- llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
264- break ;
265- case COMMON_SAMPLER_TYPE_TEMPERATURE:
266- llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
267- break ;
268- case COMMON_SAMPLER_TYPE_INFILL:
269- llama_sampler_chain_add (result->chain , llama_sampler_init_infill (vocab));
270- break ;
271- case COMMON_SAMPLER_TYPE_PENALTIES:
272- llama_sampler_chain_add (result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat , params.penalty_freq , params.penalty_present ));
273- break ;
274- default :
275- GGML_ASSERT (false && " unknown sampler type" );
276- }
241+
242+ llama_sampler_chain_add (result->chain , llama_sampler_init_dry (vocab, llama_model_n_ctx_train (model), params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
243+ }
244+ break ;
245+ case COMMON_SAMPLER_TYPE_TOP_K:
246+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
247+ break ;
248+ case COMMON_SAMPLER_TYPE_TOP_P:
249+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
250+ break ;
251+ case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
252+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ));
253+ break ;
254+ case COMMON_SAMPLER_TYPE_MIN_P:
255+ llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
256+ break ;
257+ case COMMON_SAMPLER_TYPE_XTC:
258+ llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
259+ break ;
260+ case COMMON_SAMPLER_TYPE_TYPICAL_P:
261+ llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
262+ break ;
263+ case COMMON_SAMPLER_TYPE_TEMPERATURE:
264+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
265+ break ;
266+ case COMMON_SAMPLER_TYPE_INFILL:
267+ llama_sampler_chain_add (result->chain , llama_sampler_init_infill (vocab));
268+ break ;
269+ case COMMON_SAMPLER_TYPE_PENALTIES:
270+ llama_sampler_chain_add (result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat , params.penalty_freq , params.penalty_present ));
271+ break ;
272+ default :
273+ GGML_ASSERT (false && " unknown sampler type" );
277274 }
278275 }
279276 llama_sampler_chain_add (result->chain , llama_sampler_init_dist (params.seed ));
@@ -475,6 +472,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
475472 case COMMON_SAMPLER_TYPE_TOP_K: return ' k' ;
476473 case COMMON_SAMPLER_TYPE_TYPICAL_P: return ' y' ;
477474 case COMMON_SAMPLER_TYPE_TOP_P: return ' p' ;
475+ case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return ' s' ;
478476 case COMMON_SAMPLER_TYPE_MIN_P: return ' m' ;
479477 case COMMON_SAMPLER_TYPE_TEMPERATURE: return ' t' ;
480478 case COMMON_SAMPLER_TYPE_XTC: return ' x' ;
@@ -490,6 +488,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
490488 case COMMON_SAMPLER_TYPE_TOP_K: return " top_k" ;
491489 case COMMON_SAMPLER_TYPE_TYPICAL_P: return " typ_p" ;
492490 case COMMON_SAMPLER_TYPE_TOP_P: return " top_p" ;
491+ case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return " top_n_sigma" ;
493492 case COMMON_SAMPLER_TYPE_MIN_P: return " min_p" ;
494493 case COMMON_SAMPLER_TYPE_TEMPERATURE: return " temperature" ;
495494 case COMMON_SAMPLER_TYPE_XTC: return " xtc" ;
@@ -504,6 +503,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
504503 { " dry" , COMMON_SAMPLER_TYPE_DRY },
505504 { " top_k" , COMMON_SAMPLER_TYPE_TOP_K },
506505 { " top_p" , COMMON_SAMPLER_TYPE_TOP_P },
506+ { " top_n_sigma" , COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
507507 { " typ_p" , COMMON_SAMPLER_TYPE_TYPICAL_P },
508508 { " min_p" , COMMON_SAMPLER_TYPE_MIN_P },
509509 { " temperature" , COMMON_SAMPLER_TYPE_TEMPERATURE },
@@ -517,6 +517,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
517517 std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
518518 { " top-k" , COMMON_SAMPLER_TYPE_TOP_K },
519519 { " top-p" , COMMON_SAMPLER_TYPE_TOP_P },
520+ { " top-n-sigma" , COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
520521 { " nucleus" , COMMON_SAMPLER_TYPE_TOP_P },
521522 { " typical-p" , COMMON_SAMPLER_TYPE_TYPICAL_P },
522523 { " typical" , COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -552,6 +553,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
552553 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
553554 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
554555 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
556+ { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
555557 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
556558 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
557559 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
0 commit comments