@@ -177,10 +177,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
177177 params.penalize_nl ,
178178 params.ignore_eos ));
179179
180- if (params.temp > 0 .0f ) {
181- if (params.mirostat == 0 ) {
182- for (const auto & cnstr : params.samplers ) {
183- switch (cnstr) {
180+ if (params.mirostat == 0 ) {
181+ for (const auto & cnstr : params.samplers ) {
182+ switch (cnstr) {
184183 case COMMON_SAMPLER_TYPE_DRY:
185184 {
186185 std::vector<const char *> c_breakers;
@@ -192,56 +191,43 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
192191 llama_sampler_chain_add (result->chain , llama_sampler_init_dry (model, context_size, params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
193192 }
194193 break ;
195- case COMMON_SAMPLER_TYPE_TOP_K:
196- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
197- break ;
198- case COMMON_SAMPLER_TYPE_TOP_P:
199- llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
200- break ;
201- case COMMON_SAMPLER_TYPE_MIN_P:
202- llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
203- break ;
204- case COMMON_SAMPLER_TYPE_XTC:
205- llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
206- break ;
207- case COMMON_SAMPLER_TYPE_TFS_Z:
208- llama_sampler_chain_add (result->chain , llama_sampler_init_tail_free (params.tfs_z , params.min_keep ));
209- break ;
210- case COMMON_SAMPLER_TYPE_TYPICAL_P:
211- llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
212- break ;
213- case COMMON_SAMPLER_TYPE_TEMPERATURE:
214- llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
215- break ;
216- case COMMON_SAMPLER_TYPE_INFILL:
217- llama_sampler_chain_add (result->chain , llama_sampler_init_infill (model));
218- break ;
219- default :
220- GGML_ASSERT (false && " unknown sampler type" );
221- }
194+ case COMMON_SAMPLER_TYPE_TOP_K:
195+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
196+ break ;
197+ case COMMON_SAMPLER_TYPE_TOP_P:
198+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
199+ break ;
200+ case COMMON_SAMPLER_TYPE_MIN_P:
201+ llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
202+ break ;
203+ case COMMON_SAMPLER_TYPE_XTC:
204+ llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
205+ break ;
206+ case COMMON_SAMPLER_TYPE_TFS_Z:
207+ llama_sampler_chain_add (result->chain , llama_sampler_init_tail_free (params.tfs_z , params.min_keep ));
208+ break ;
209+ case COMMON_SAMPLER_TYPE_TYPICAL_P:
210+ llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
211+ break ;
212+ case COMMON_SAMPLER_TYPE_TEMPERATURE:
213+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
214+ break ;
215+ case COMMON_SAMPLER_TYPE_INFILL:
216+ llama_sampler_chain_add (result->chain , llama_sampler_init_infill (model));
217+ break ;
218+ default :
219+ GGML_ASSERT (false && " unknown sampler type" );
222220 }
223- llama_sampler_chain_add (result->chain , llama_sampler_init_softmax ());
224- llama_sampler_chain_add (result->chain , llama_sampler_init_dist (params.seed ));
225- } else if (params.mirostat == 1 ) {
226- llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
227- llama_sampler_chain_add (result->chain , llama_sampler_init_mirostat (llama_n_vocab (model), params.seed , params.mirostat_tau , params.mirostat_eta , 100 ));
228- } else if (params.mirostat == 2 ) {
229- llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
230- llama_sampler_chain_add (result->chain , llama_sampler_init_mirostat_v2 (params.seed , params.mirostat_tau , params.mirostat_eta ));
231- } else {
232- GGML_ASSERT (false && " unknown mirostat version" );
233221 }
222+ llama_sampler_chain_add (result->chain , llama_sampler_init_dist (params.seed ));
223+ } else if (params.mirostat == 1 ) {
224+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
225+ llama_sampler_chain_add (result->chain , llama_sampler_init_mirostat (llama_n_vocab (model), params.seed , params.mirostat_tau , params.mirostat_eta , 100 ));
226+ } else if (params.mirostat == 2 ) {
227+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
228+ llama_sampler_chain_add (result->chain , llama_sampler_init_mirostat_v2 (params.seed , params.mirostat_tau , params.mirostat_eta ));
234229 } else {
235- if (params.n_probs > 0 ) {
236- // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
237- // ref: https://github.com/ggerganov/llama.cpp/pull/9605
238- //
239- // the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
240- // it is much faster, since we avoid sorting all tokens and should give a good approximation
241- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.n_probs ));
242- llama_sampler_chain_add (result->chain , llama_sampler_init_softmax ());
243- }
244- llama_sampler_chain_add (result->chain , llama_sampler_init_greedy ());
230+ GGML_ASSERT (false && " unknown mirostat version" );
245231 }
246232
247233 return result;
0 commit comments