Skip to content

Commit ddc3c22

Browse files
committed
initial sampling changes:
1 parent f7cd133 commit ddc3c22

File tree

3 files changed

+64
-8
lines changed

3 files changed

+64
-8
lines changed

common/common.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ enum common_sampler_type {
9595
COMMON_SAMPLER_TYPE_XTC = 8,
9696
COMMON_SAMPLER_TYPE_INFILL = 9,
9797
COMMON_SAMPLER_TYPE_PENALTIES = 10,
98+
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11
9899
};
99100

100101
// dimensionality reduction methods, used by cvector-generator
@@ -128,6 +129,7 @@ struct common_params_sampling {
128129
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
129130
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
130131
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
132+
int32_t top_n_sigma = 2;
131133
float mirostat_tau = 5.00f; // target entropy
132134
float mirostat_eta = 0.10f; // learning rate
133135
bool ignore_eos = false;
@@ -146,6 +148,7 @@ struct common_params_sampling {
146148
COMMON_SAMPLER_TYPE_MIN_P,
147149
COMMON_SAMPLER_TYPE_XTC,
148150
COMMON_SAMPLER_TYPE_TEMPERATURE,
151+
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
149152
};
150153

151154
std::string grammar; // optional BNF-like grammar to constrain sampling

common/sampling.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,28 +176,32 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
176176
}
177177
break;
178178
case COMMON_SAMPLER_TYPE_TOP_K:
179-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
179+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
180180
break;
181181
case COMMON_SAMPLER_TYPE_TOP_P:
182-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
182+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
183183
break;
184184
case COMMON_SAMPLER_TYPE_MIN_P:
185-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
185+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
186186
break;
187187
case COMMON_SAMPLER_TYPE_XTC:
188-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
188+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
189189
break;
190190
case COMMON_SAMPLER_TYPE_TYPICAL_P:
191-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
191+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
192192
break;
193193
case COMMON_SAMPLER_TYPE_TEMPERATURE:
194-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
194+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
195195
break;
196196
case COMMON_SAMPLER_TYPE_INFILL:
197-
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
197+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
198198
break;
199199
case COMMON_SAMPLER_TYPE_PENALTIES:
200-
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
200+
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
201+
break;
202+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
203+
// llama_sampler_chain_add(result->chain, )
204+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma))
201205
break;
202206
default:
203207
GGML_ASSERT(false && "unknown sampler type");
@@ -407,6 +411,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
407411
case COMMON_SAMPLER_TYPE_XTC: return 'x';
408412
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
409413
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
414+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
410415
default : return '?';
411416
}
412417
}
@@ -422,6 +427,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
422427
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
423428
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
424429
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
430+
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
425431
default : return "";
426432
}
427433
}
@@ -437,6 +443,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
437443
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
438444
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
439445
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
446+
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
440447
};
441448

442449
// since samplers names are written multiple ways
@@ -451,6 +458,9 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
451458
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
452459
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
453460
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
461+
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
462+
{ "top-nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
463+
{ "top_nsigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
454464
};
455465

456466
std::vector<common_sampler_type> samplers;
@@ -484,6 +494,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
484494
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
485495
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
486496
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
497+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA}
487498
};
488499

489500
std::vector<common_sampler_type> samplers;

src/llama-sampling.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,6 +1645,48 @@ struct llama_sampler * llama_sampler_init_penalties(
16451645
};
16461646
}
16471647

1648+
// top-n-sigma
1649+
1650+
struct llama_sampler_top_n_sigma {
1651+
const int32_t n;
1652+
};
1653+
1654+
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
1655+
return "top-n-sigma";
1656+
}
1657+
1658+
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1659+
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1660+
llama_sampler_top_n_sigma_impl(cur_p, ctx->n);
1661+
}
1662+
1663+
// static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
1664+
// const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
1665+
// return llama_sampler_init_top_k(ctx->k);
1666+
// }
1667+
1668+
// static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
1669+
// delete (llama_sampler_top_k *) smpl->ctx;
1670+
// }
1671+
1672+
// static struct llama_sampler_i llama_sampler_top_k_i = {
1673+
// /* .name = */ llama_sampler_top_k_name,
1674+
// /* .accept = */ nullptr,
1675+
// /* .apply = */ llama_sampler_top_k_apply,
1676+
// /* .reset = */ nullptr,
1677+
// /* .clone = */ llama_sampler_top_k_clone,
1678+
// /* .free = */ llama_sampler_top_k_free,
1679+
// };
1680+
1681+
// struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
1682+
// return new llama_sampler {
1683+
// /* .iface = */ &llama_sampler_top_k_i,
1684+
// /* .ctx = */ new llama_sampler_top_k {
1685+
// /* .k = */ k,
1686+
// },
1687+
// };
1688+
// }
1689+
16481690
// DRY
16491691

16501692
struct llama_sampler_dry {

0 commit comments

Comments
 (0)