@@ -171,60 +171,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
171171 params.penalize_nl ,
172172 params.ignore_eos ));
173173
174- if (params.temp > 0 .0f ) {
175- if (params.mirostat == 0 ) {
176- for (const auto & cnstr : params.samplers ) {
177- switch (cnstr) {
178- case COMMON_SAMPLER_TYPE_TOP_K:
179- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
180- break ;
181- case COMMON_SAMPLER_TYPE_TOP_P:
182- llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
183- break ;
184- case COMMON_SAMPLER_TYPE_MIN_P:
185- llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
186- break ;
187- case COMMON_SAMPLER_TYPE_XTC:
188- llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
189- break ;
190- case COMMON_SAMPLER_TYPE_TFS_Z:
191- llama_sampler_chain_add (result->chain , llama_sampler_init_tail_free (params.tfs_z , params.min_keep ));
192- break ;
193- case COMMON_SAMPLER_TYPE_TYPICAL_P:
194- llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
195- break ;
196- case COMMON_SAMPLER_TYPE_TEMPERATURE:
197- llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
198- break ;
199- case COMMON_SAMPLER_TYPE_INFILL:
200- llama_sampler_chain_add (result->chain , llama_sampler_init_infill (model));
201- break ;
202- default :
203- GGML_ASSERT (false && " unknown sampler type" );
204- }
174+ if (params.mirostat == 0 ) {
175+ for (const auto & cnstr : params.samplers ) {
176+ switch (cnstr) {
177+ case COMMON_SAMPLER_TYPE_TOP_K:
178+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
179+ break ;
180+ case COMMON_SAMPLER_TYPE_TOP_P:
181+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
182+ break ;
183+ case COMMON_SAMPLER_TYPE_MIN_P:
184+ llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
185+ break ;
186+ case COMMON_SAMPLER_TYPE_XTC:
187+ llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
188+ break ;
189+ case COMMON_SAMPLER_TYPE_TFS_Z:
190+ llama_sampler_chain_add (result->chain , llama_sampler_init_tail_free (params.tfs_z , params.min_keep ));
191+ break ;
192+ case COMMON_SAMPLER_TYPE_TYPICAL_P:
193+ llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
194+ break ;
195+ case COMMON_SAMPLER_TYPE_TEMPERATURE:
196+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
197+ break ;
198+ case COMMON_SAMPLER_TYPE_INFILL:
199+ llama_sampler_chain_add (result->chain , llama_sampler_init_infill (model));
200+ break ;
201+ default :
202+ GGML_ASSERT (false && " unknown sampler type" );
205203 }
206- llama_sampler_chain_add (result->chain , llama_sampler_init_softmax ());
207- llama_sampler_chain_add (result->chain , llama_sampler_init_dist (params.seed ));
208- } else if (params.mirostat == 1 ) {
209- llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
210- llama_sampler_chain_add (result->chain , llama_sampler_init_mirostat (llama_n_vocab (model), params.seed , params.mirostat_tau , params.mirostat_eta , 100 ));
211- } else if (params.mirostat == 2 ) {
212- llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
213- llama_sampler_chain_add (result->chain , llama_sampler_init_mirostat_v2 (params.seed , params.mirostat_tau , params.mirostat_eta ));
214- } else {
215- GGML_ASSERT (false && " unknown mirostat version" );
216204 }
205+ llama_sampler_chain_add (result->chain , llama_sampler_init_dist (params.seed ));
206+ } else if (params.mirostat == 1 ) {
207+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
208+ llama_sampler_chain_add (result->chain , llama_sampler_init_mirostat (llama_n_vocab (model), params.seed , params.mirostat_tau , params.mirostat_eta , 100 ));
209+ } else if (params.mirostat == 2 ) {
210+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
211+ llama_sampler_chain_add (result->chain , llama_sampler_init_mirostat_v2 (params.seed , params.mirostat_tau , params.mirostat_eta ));
217212 } else {
218- if (params.n_probs > 0 ) {
219- // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
220- // ref: https://github.com/ggerganov/llama.cpp/pull/9605
221- //
222- // the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
223- // it is much faster, since we avoid sorting all tokens and should give a good approximation
224- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.n_probs ));
225- llama_sampler_chain_add (result->chain , llama_sampler_init_softmax ());
226- }
227- llama_sampler_chain_add (result->chain , llama_sampler_init_greedy ());
213+ GGML_ASSERT (false && " unknown mirostat version" );
228214 }
229215
230216 return result;
0 commit comments