@@ -188,63 +188,64 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
188188 params.logit_bias .data ()));
189189
190190 if (params.mirostat == 0 ) {
191- for (const auto & cnstr : params.samplers ) {
192- switch (cnstr) {
193- case COMMON_SAMPLER_TYPE_DRY:
194- {
195- std::vector<const char *> c_breakers;
196- c_breakers.reserve (params.dry_sequence_breakers .size ());
197- for (const auto & str : params.dry_sequence_breakers ) {
198- c_breakers.push_back (str.c_str ());
191+ if (params.top_n_sigma >= 0 ) {
192+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
193+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
194+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ));
195+ } else {
196+ for (const auto & cnstr : params.samplers ) {
197+ switch (cnstr) {
198+ case COMMON_SAMPLER_TYPE_DRY:
199+ {
200+ std::vector<const char *> c_breakers;
201+ c_breakers.reserve (params.dry_sequence_breakers .size ());
202+ for (const auto & str : params.dry_sequence_breakers ) {
203+ c_breakers.push_back (str.c_str ());
204+ }
205+
206+ llama_sampler_chain_add (
207+ result->chain , llama_sampler_init_dry (
208+ vocab, llama_model_n_ctx_train (model), params.dry_multiplier ,
209+ params.dry_base , params.dry_allowed_length ,
210+ params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
199211 }
200-
212+ break ;
213+ case COMMON_SAMPLER_TYPE_TOP_K:
214+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
215+ break ;
216+ case COMMON_SAMPLER_TYPE_TOP_P:
217+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
218+ break ;
219+ case COMMON_SAMPLER_TYPE_MIN_P:
220+ llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
221+ break ;
222+ case COMMON_SAMPLER_TYPE_XTC:
223+ llama_sampler_chain_add (result->chain ,
224+ llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold ,
225+ params.min_keep , params.seed ));
226+ break ;
227+ case COMMON_SAMPLER_TYPE_TYPICAL_P:
228+ llama_sampler_chain_add (result->chain ,
229+ llama_sampler_init_typical (params.typ_p , params.min_keep ));
230+ break ;
231+ case COMMON_SAMPLER_TYPE_TEMPERATURE:
201232 llama_sampler_chain_add (
202233 result->chain ,
203- llama_sampler_init_dry (vocab, llama_model_n_ctx_train (model), params.dry_multiplier ,
204- params.dry_base , params.dry_allowed_length ,
205- params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
206- }
207- break ;
208- case COMMON_SAMPLER_TYPE_TOP_K:
209- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
210- break ;
211- case COMMON_SAMPLER_TYPE_TOP_P:
212- llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
213- break ;
214- case COMMON_SAMPLER_TYPE_MIN_P:
215- llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
216- break ;
217- case COMMON_SAMPLER_TYPE_N_SIGMA:
218- if (params.top_n_sigma >= 0 ) {
219- llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ));
220- }
221- break ;
222- case COMMON_SAMPLER_TYPE_XTC:
223- llama_sampler_chain_add (result->chain ,
224- llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold ,
225- params.min_keep , params.seed ));
226- break ;
227- case COMMON_SAMPLER_TYPE_TYPICAL_P:
228- llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
229- break ;
230- case COMMON_SAMPLER_TYPE_TEMPERATURE:
231- llama_sampler_chain_add (
232- result->chain ,
233- llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
234- break ;
235- case COMMON_SAMPLER_TYPE_INFILL:
236- llama_sampler_chain_add (result->chain , llama_sampler_init_infill (vocab));
237- break ;
238- case COMMON_SAMPLER_TYPE_PENALTIES:
239- llama_sampler_chain_add (result->chain ,
240- llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat ,
241- params.penalty_freq , params.penalty_present ));
242- break ;
243- default :
244- GGML_ASSERT (false && " unknown sampler type" );
234+ llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
235+ break ;
236+ case COMMON_SAMPLER_TYPE_INFILL:
237+ llama_sampler_chain_add (result->chain , llama_sampler_init_infill (vocab));
238+ break ;
239+ case COMMON_SAMPLER_TYPE_PENALTIES:
240+ llama_sampler_chain_add (
241+ result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat ,
242+ params.penalty_freq , params.penalty_present ));
243+ break ;
244+ default :
245+ GGML_ASSERT (false && " unknown sampler type" );
246+ }
245247 }
246248 }
247-
248249 llama_sampler_chain_add (result->chain , llama_sampler_init_dist (params.seed ));
249250 } else if (params.mirostat == 1 ) {
250251 llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
0 commit comments