@@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
161161 params.logit_bias .size (),
162162 params.logit_bias .data ()));
163163
164- llama_sampler_chain_add (result->chain ,
165- llama_sampler_init_penalties (
166- llama_n_vocab (model),
167- llama_token_eos (model),
168- llama_token_nl (model),
169- params.penalty_last_n ,
170- params.penalty_repeat ,
171- params.penalty_freq ,
172- params.penalty_present ,
173- params.penalize_nl ,
174- params.ignore_eos ));
175-
176164 if (params.mirostat == 0 ) {
177165 for (const auto & cnstr : params.samplers ) {
178166 switch (cnstr) {
179- case COMMON_SAMPLER_TYPE_DRY:
167+ case COMMON_SAMPLER_TYPE_DRY:
180168 {
181- std::vector<const char *> c_breakers;
169+ std::vector<const char *> c_breakers;
182170 c_breakers.reserve (params.dry_sequence_breakers .size ());
183- for (const auto & str : params.dry_sequence_breakers ) {
171+ for (const auto & str : params.dry_sequence_breakers ) {
184172 c_breakers.push_back (str.c_str ());
185173 }
186174
187175 llama_sampler_chain_add (result->chain , llama_sampler_init_dry (model, params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
188176 }
189- break ;
177+ break ;
190178 case COMMON_SAMPLER_TYPE_TOP_K:
191179 llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
192180 break ;
@@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
208196 case COMMON_SAMPLER_TYPE_INFILL:
209197 llama_sampler_chain_add (result->chain , llama_sampler_init_infill (model));
210198 break ;
199+ case COMMON_SAMPLER_TYPE_PENALTIES:
200+ llama_sampler_chain_add (result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat , params.penalty_freq , params.penalty_present ));
201+ break ;
211202 default :
212203 GGML_ASSERT (false && " unknown sampler type" );
213204 }
@@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
415406 case COMMON_SAMPLER_TYPE_TEMPERATURE: return ' t' ;
416407 case COMMON_SAMPLER_TYPE_XTC: return ' x' ;
417408 case COMMON_SAMPLER_TYPE_INFILL: return ' i' ;
409+ case COMMON_SAMPLER_TYPE_PENALTIES: return ' e' ;
418410 default : return ' ?' ;
419411 }
420412}
@@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
429421 case COMMON_SAMPLER_TYPE_TEMPERATURE: return " temperature" ;
430422 case COMMON_SAMPLER_TYPE_XTC: return " xtc" ;
431423 case COMMON_SAMPLER_TYPE_INFILL: return " infill" ;
424+ case COMMON_SAMPLER_TYPE_PENALTIES: return " penalties" ;
432425 default : return " " ;
433426 }
434427}
@@ -443,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
443436 { " temperature" , COMMON_SAMPLER_TYPE_TEMPERATURE },
444437 { " xtc" , COMMON_SAMPLER_TYPE_XTC },
445438 { " infill" , COMMON_SAMPLER_TYPE_INFILL },
439+ { " penalties" , COMMON_SAMPLER_TYPE_PENALTIES },
446440 };
447441
448442 // since samplers names are written multiple ways
@@ -489,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
489483 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
490484 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
491485 { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
486+ { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
492487 };
493488
494489 std::vector<common_sampler_type> samplers;
0 commit comments