Skip to content

Commit ca992ad

Browse files
committed
Integrate top_n_sigma into main sampler chain
1 parent a558d3a commit ca992ad

File tree

1 file changed

+41
-45
lines changed

1 file changed

+41
-45
lines changed

common/sampling.cpp

Lines changed: 41 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -188,52 +188,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
188188
params.logit_bias.data()));
189189

190190
if (params.mirostat == 0) {
191-
if (params.top_n_sigma >= 0) {
192-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
193-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
194-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
195-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
196-
} else {
197-
for (const auto & cnstr : params.samplers) {
198-
switch (cnstr) {
199-
case COMMON_SAMPLER_TYPE_DRY:
200-
{
201-
std::vector<const char *> c_breakers;
202-
c_breakers.reserve(params.dry_sequence_breakers.size());
203-
for (const auto & str : params.dry_sequence_breakers) {
204-
c_breakers.push_back(str.c_str());
205-
}
206-
207-
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
191+
for (const auto & cnstr : params.samplers) {
192+
switch (cnstr) {
193+
case COMMON_SAMPLER_TYPE_DRY:
194+
{
195+
std::vector<const char *> c_breakers;
196+
c_breakers.reserve(params.dry_sequence_breakers.size());
197+
for (const auto & str : params.dry_sequence_breakers) {
198+
c_breakers.push_back(str.c_str());
208199
}
209-
break;
210-
case COMMON_SAMPLER_TYPE_TOP_K:
211-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
212-
break;
213-
case COMMON_SAMPLER_TYPE_TOP_P:
214-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
215-
break;
216-
case COMMON_SAMPLER_TYPE_MIN_P:
217-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
218-
break;
219-
case COMMON_SAMPLER_TYPE_XTC:
220-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
221-
break;
222-
case COMMON_SAMPLER_TYPE_TYPICAL_P:
223-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
224-
break;
225-
case COMMON_SAMPLER_TYPE_TEMPERATURE:
226-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
227-
break;
228-
case COMMON_SAMPLER_TYPE_INFILL:
229-
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
230-
break;
231-
case COMMON_SAMPLER_TYPE_PENALTIES:
232-
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
233-
break;
234-
default:
235-
GGML_ASSERT(false && "unknown sampler type");
236-
}
200+
201+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
202+
}
203+
break;
204+
case COMMON_SAMPLER_TYPE_TOP_K:
205+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
206+
break;
207+
case COMMON_SAMPLER_TYPE_TOP_P:
208+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
209+
break;
210+
case COMMON_SAMPLER_TYPE_MIN_P:
211+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
212+
break;
213+
case COMMON_SAMPLER_TYPE_XTC:
214+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
215+
break;
216+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
217+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
218+
break;
219+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
220+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
221+
break;
222+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
223+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
224+
break;
225+
case COMMON_SAMPLER_TYPE_INFILL:
226+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
227+
break;
228+
case COMMON_SAMPLER_TYPE_PENALTIES:
229+
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
230+
break;
231+
default:
232+
GGML_ASSERT(false && "unknown sampler type");
237233
}
238234
}
239235
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));

0 commit comments

Comments
 (0)