Skip to content

Commit ff8b612

Browse files
committed
revert: sampler ordering
1 parent 1bdb603 commit ff8b612

File tree

1 file changed: +53 additions, -52 deletions

1 file changed

+53
-52
lines changed

common/sampling.cpp

Lines changed: 53 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -188,63 +188,64 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
188 188
params.logit_bias.data()));
189 189

190 190
if (params.mirostat == 0) {
191-
for (const auto & cnstr : params.samplers) {
192-
switch (cnstr) {
193-
case COMMON_SAMPLER_TYPE_DRY:
194-
{
195-
std::vector<const char *> c_breakers;
196-
c_breakers.reserve(params.dry_sequence_breakers.size());
197-
for (const auto & str : params.dry_sequence_breakers) {
198-
c_breakers.push_back(str.c_str());
191+
if (params.top_n_sigma >= 0) {
192+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.top_k));
193+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
194+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
195+
} else {
196+
for (const auto & cnstr : params.samplers) {
197+
switch (cnstr) {
198+
case COMMON_SAMPLER_TYPE_DRY:
199+
{
200+
std::vector<const char *> c_breakers;
201+
c_breakers.reserve(params.dry_sequence_breakers.size());
202+
for (const auto & str : params.dry_sequence_breakers) {
203+
c_breakers.push_back(str.c_str());
204+
}
205+
206+
llama_sampler_chain_add(
207+
result->chain, llama_sampler_init_dry(
208+
vocab, llama_model_n_ctx_train(model), params.dry_multiplier,
209+
params.dry_base, params.dry_allowed_length,
210+
params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
199 211
}
200-
212+
break;
213+
case COMMON_SAMPLER_TYPE_TOP_K:
214+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.top_k));
215+
break;
216+
case COMMON_SAMPLER_TYPE_TOP_P:
217+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p(params.top_p, params.min_keep));
218+
break;
219+
case COMMON_SAMPLER_TYPE_MIN_P:
220+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p(params.min_p, params.min_keep));
221+
break;
222+
case COMMON_SAMPLER_TYPE_XTC:
223+
llama_sampler_chain_add(result->chain,
224+
llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold,
225+
params.min_keep, params.seed));
226+
break;
227+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
228+
llama_sampler_chain_add(result->chain,
229+
llama_sampler_init_typical(params.typ_p, params.min_keep));
230+
break;
231+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
201 232
llama_sampler_chain_add(
202 233
result->chain,
203-
llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier,
204-
params.dry_base, params.dry_allowed_length,
205-
params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
206-
}
207-
break;
208-
case COMMON_SAMPLER_TYPE_TOP_K:
209-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.top_k));
210-
break;
211-
case COMMON_SAMPLER_TYPE_TOP_P:
212-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p(params.top_p, params.min_keep));
213-
break;
214-
case COMMON_SAMPLER_TYPE_MIN_P:
215-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p(params.min_p, params.min_keep));
216-
break;
217-
case COMMON_SAMPLER_TYPE_N_SIGMA:
218-
if (params.top_n_sigma >= 0) {
219-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
220-
}
221-
break;
222-
case COMMON_SAMPLER_TYPE_XTC:
223-
llama_sampler_chain_add(result->chain,
224-
llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold,
225-
params.min_keep, params.seed));
226-
break;
227-
case COMMON_SAMPLER_TYPE_TYPICAL_P:
228-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical(params.typ_p, params.min_keep));
229-
break;
230-
case COMMON_SAMPLER_TYPE_TEMPERATURE:
231-
llama_sampler_chain_add(
232-
result->chain,
233-
llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
234-
break;
235-
case COMMON_SAMPLER_TYPE_INFILL:
236-
llama_sampler_chain_add(result->chain, llama_sampler_init_infill(vocab));
237-
break;
238-
case COMMON_SAMPLER_TYPE_PENALTIES:
239-
llama_sampler_chain_add(result->chain,
240-
llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat,
241-
params.penalty_freq, params.penalty_present));
242-
break;
243-
default:
244-
GGML_ASSERT(false && "unknown sampler type");
234+
llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
235+
break;
236+
case COMMON_SAMPLER_TYPE_INFILL:
237+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill(vocab));
238+
break;
239+
case COMMON_SAMPLER_TYPE_PENALTIES:
240+
llama_sampler_chain_add(
241+
result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat,
242+
params.penalty_freq, params.penalty_present));
243+
break;
244+
default:
245+
GGML_ASSERT(false && "unknown sampler type");
246+
}
245 247
}
246 248
}
247-
248 249
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
249 250
} else if (params.mirostat == 1) {
250 251
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));

0 commit comments

Comments (0)