Changes from 28 commits

Commits (30)
87384fb  K-Shift commit (MaggotHATE, Oct 25, 2024)
070f954  Fixed style (MaggotHATE, Oct 25, 2024)
5237aa4  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Oct 25, 2024)
48b715d  Fixes and tests (MaggotHATE, Oct 26, 2024)
6101797  Merge branch 'k-shift2' of https://github.com/MaggotHATE/llama.cpp-gr… (MaggotHATE, Oct 26, 2024)
ee95274  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Oct 26, 2024)
c95c957  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Oct 28, 2024)
968b4ba  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Oct 28, 2024)
e83245e  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Oct 29, 2024)
9c233c7  Merge branch 'master' into k-shift2 (MaggotHATE, Oct 29, 2024)
4d1ab99  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Oct 29, 2024)
62a878b  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Oct 30, 2024)
c616263  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Oct 30, 2024)
5144fd9  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Oct 31, 2024)
2b7be22  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 1, 2024)
aa458d1  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 1, 2024)
9ef8cb5  Removed custom reset (MaggotHATE, Nov 1, 2024)
f853c3e  Revert back reset function (MaggotHATE, Nov 1, 2024)
ae8b7eb  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 1, 2024)
e5ce8b4  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 2, 2024)
8411453  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 2, 2024)
df01a89  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 3, 2024)
77afcd1  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 4, 2024)
af46dc2  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 4, 2024)
31b6bea  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 5, 2024)
5ed18f9  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 6, 2024)
f7d3fe1  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 7, 2024)
840a2b1  Merge branch 'ggerganov:master' into k-shift2 (MaggotHATE, Nov 7, 2024)
877a495  Fix to guarantee K-Shift on the first step only (MaggotHATE, Nov 8, 2024)
9cae93c  Added `shift_p_min` parameter to control probabilities (MaggotHATE, Nov 9, 2024)
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -922,6 +922,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sparams.temp = std::max(params.sparams.temp, 0.0f);
}
).set_sparam());
add_opt(common_arg(
{"--k-shift"}, "N",
string_format("k-shift sampling (default: %d, 0 = disabled)", params.sparams.k_shift),
[](common_params & params, int value) {
params.sparams.k_shift = value;
}
).set_sparam());
add_opt(common_arg(
{"--top-k"}, "N",
string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),
1 change: 1 addition & 0 deletions common/common.cpp
@@ -2096,6 +2096,7 @@ void yaml_dump_non_result_info(FILE * stream, const common_params & params, cons
yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector);

fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency());
fprintf(stream, "k_shift: %d # default: 0\n", sparams.k_shift);
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
19 changes: 11 additions & 8 deletions common/common.h
@@ -85,14 +85,15 @@ enum llama_example {
enum common_sampler_type {
COMMON_SAMPLER_TYPE_NONE = 0,
COMMON_SAMPLER_TYPE_DRY = 1,
COMMON_SAMPLER_TYPE_TOP_K = 2,
COMMON_SAMPLER_TYPE_TOP_P = 3,
COMMON_SAMPLER_TYPE_MIN_P = 4,
//COMMON_SAMPLER_TYPE_TFS_Z = 5,
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_K_SHIFT = 2,
COMMON_SAMPLER_TYPE_TOP_K = 3,
COMMON_SAMPLER_TYPE_TOP_P = 4,
COMMON_SAMPLER_TYPE_MIN_P = 5,
//COMMON_SAMPLER_TYPE_TFS_Z = 6,
COMMON_SAMPLER_TYPE_TYPICAL_P = 7,
COMMON_SAMPLER_TYPE_TEMPERATURE = 8,
COMMON_SAMPLER_TYPE_XTC = 9,
COMMON_SAMPLER_TYPE_INFILL = 10,
};

// dimensionality reduction methods, used by cvector-generator
@@ -108,6 +109,7 @@ struct common_sampler_params {
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t k_shift = 0; // 0 = disabled
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
@@ -137,6 +139,7 @@ struct common_sampler_params {

std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_K_SHIFT,
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P,
12 changes: 10 additions & 2 deletions common/sampling.cpp
@@ -131,11 +131,11 @@ std::string common_sampler_params::print() const {
snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tk_shift = %d, top_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
k_shift, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
mirostat, mirostat_eta, mirostat_tau);

return std::string(result);
@@ -187,6 +187,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_K_SHIFT:
llama_sampler_chain_add(result->chain, llama_sampler_init_k_shift (params.k_shift));
break;
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break;
@@ -369,6 +372,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: return 'd';
case COMMON_SAMPLER_TYPE_K_SHIFT: return 's';
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
@@ -383,6 +387,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY: return "dry";
case COMMON_SAMPLER_TYPE_K_SHIFT: return "k_shift";
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
@@ -398,6 +403,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "k_shift", COMMON_SAMPLER_TYPE_K_SHIFT },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
@@ -410,6 +416,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
// make it ready for both system names and input names
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "k-shift", COMMON_SAMPLER_TYPE_K_SHIFT },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -443,6 +450,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
std::unordered_map<char, common_sampler_type> sampler_name_map = {
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_K_SHIFT), COMMON_SAMPLER_TYPE_K_SHIFT },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
8 changes: 8 additions & 0 deletions examples/main/README.md
@@ -211,6 +211,14 @@ DRY sampling provides more nuanced control over text generation, particularly fo

Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"`

### K-Shift Sampling

- `--k-shift N`: Shift the first token selection by removing the top N tokens once (default: 0).

K-Shift is a sampling method that guides models away from the most obvious output, eliciting reasoning and analysis. It cuts out the top k tokens once, at the very beginning of inference, so that generation starts from a less obvious path without steering the model any further. The method was mentioned in the paper [Chain-of-Thought Reasoning without Prompting](https://arxiv.org/pdf/2402.10200) as a simple trick for guiding a model towards reasoning. In practice, K-Shift can improve the quality of reasoning, help bypass bias or censorship in certain cases, and may also serve as a diagnostic tool. K-Shift is intended to be used with greedy sampling (`--k-shift 10 --top-k 1`), but it can help with creative writing too, albeit not as much as XTC. The default value is 0.

Example usage: `--k-shift 10`
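
In essence, the sampler sorts candidates by logit and drops the top k once. Below is a minimal, self-contained sketch of that operation, using a plain vector of token/logit pairs instead of llama.cpp's internal `llama_token_data_array` (the names here are illustrative only):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Illustrative stand-in for llama.cpp's internal candidate entries.
struct candidate { int token; float logit; };

// Drop the k most probable candidates. In the sampler this happens once,
// on the first sampling step only; later steps are left untouched.
static void k_shift(std::vector<candidate> & cur, int k) {
    if (k <= 0 || k >= (int) cur.size()) {
        return; // disabled, or the shift would remove every candidate
    }
    std::sort(cur.begin(), cur.end(),
        [](const candidate & a, const candidate & b) { return a.logit > b.logit; });
    cur.erase(cur.begin(), cur.begin() + k);
}

int main() {
    std::vector<candidate> cur = { {0, 4.0f}, {1, 3.0f}, {2, 2.0f}, {3, 1.0f} };
    k_shift(cur, 2); // tokens 0 and 1 are cut; greedy sampling now picks token 2
    for (const auto & c : cur) {
        printf("token %d, logit %.1f\n", c.token, c.logit);
    }
    return 0;
}
```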

### Top-K Sampling

- `--top-k N`: Limit the next token selection to the K most probable tokens (default: 40).
4 changes: 3 additions & 1 deletion examples/server/public/index-new.html
@@ -44,7 +44,8 @@
dry_base: 1.75, // 0.0 = disabled
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
top_k: 0, // <= 0 to use vocab size
k_shift: 0, // 0 = disabled
top_k: 0, // <= 0 to use vocab size
top_p: 1.0, // 1.0 = disabled
min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4
xtc_probability: 0.0, // 0 = disabled;
@@ -834,6 +835,7 @@
<details>
<summary><span class="summary-title">Further Options</span></summary>
<fieldset class="params">
${IntField({ label: "K-Shift", title: "Cuts out first k tokens once at the start of sampling. Intended to use with greedy sampling.", max: 100, min: 0, step: 1, name: "k_shift", value: params.value.k_shift })}
${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })}
${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })}
${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })}
2 changes: 2 additions & 0 deletions examples/server/public/index.html
@@ -308,6 +308,7 @@
dry_base: 1.75, // 0.0 = disabled
dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
k_shift: 0, // 0 = disabled
top_k: 40, // <= 0 to use vocab size
top_p: 0.95, // 1.0 = disabled
min_p: 0.05, // 0 = disabled
@@ -1007,6 +1008,7 @@
${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
${IntField({ label: "K-shift", max: 100, min: -1, name: "k_shift", value: params.value.k_shift })}
${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}
2 changes: 2 additions & 0 deletions examples/server/server.cpp
@@ -801,6 +801,7 @@ struct server_context {
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
slot.sparams.k_shift = json_value(data, "k_shift", default_sparams.k_shift);
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
@@ -1140,6 +1141,7 @@ struct server_context {
{"temperature", slot.sparams.temp},
{"dynatemp_range", slot.sparams.dynatemp_range},
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
{"k_shift", slot.sparams.k_shift},
{"top_k", slot.sparams.top_k},
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},
3 changes: 3 additions & 0 deletions include/llama.h
@@ -1097,6 +1097,9 @@ extern "C" {
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);


/// @details K-Shift sampler that removes the top k candidates once, on the first sampling step only
LLAMA_API struct llama_sampler * llama_sampler_init_k_shift (int32_t k);

/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
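As a usage sketch, the new sampler can be wired into a chain with the existing chain API to reproduce the greedy setup recommended in the README (`--k-shift 10 --top-k 1`); the surrounding model and context setup from the llama.cpp examples is assumed:

```cpp
// k-shift 10 followed by greedy selection
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

llama_sampler_chain_add(chain, llama_sampler_init_k_shift(10)); // drop the top 10 tokens once
llama_sampler_chain_add(chain, llama_sampler_init_top_k(1));    // keep only the best remaining token
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
```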
69 changes: 69 additions & 0 deletions src/llama-sampling.cpp
@@ -188,6 +188,17 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
cur_p->size = k;
}

static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k) {
// sort before shifting
std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});

// shift the window start so that token #[k] becomes the first candidate
cur_p->data += k;
cur_p->size -= k;
Review comment: This does not match the paper, in which (AFAICT) exactly the k-th token is selected, rather than sampling from all tokens except the top k-1 ones.

@MaggotHATE (Contributor, Author) replied on Nov 8, 2024: The choice of token will be handled by the final step in the sampling queue, so greedy sampling would be needed to match the effect described in the paper. Considering that the greedy sampler specifically was removed recently, I don't think introducing another "final step" sampler would be OK.
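
(With `--k-shift N --top-k 1`, the shift removes the tokens ranked 1 through N and greedy selection then picks what was originally rank N+1, i.e. exactly the paper's k-th alternative for k = N, counting the argmax as the 0th.)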

}

static uint32_t get_rng_seed(uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
// use system clock if std::random_device is not a true RNG
Expand Down Expand Up @@ -1082,6 +1093,64 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
};
}

// k-shift

struct llama_sampler_k_shift {
const int32_t k;
bool k_set;
};

static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*smpl*/) {
return "k-shift";
}

static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_k_shift *) smpl->ctx;

if (ctx->k_set == true
|| ctx->k <= 0
|| ctx->k >= (int) cur_p->size) {
Review comment: There appears to be a bug here: if, at the first token position, ctx->k >= (int) cur_p->size (e.g. because a preceding truncation sampler has already removed too many tokens), then we return and k_set remains false. This means that K-Shift will only take effect on the second (or later) token of the output, violating its contract. One possible fix is sketched after this function below.

return;
}

llama_sampler_top_shift_impl(cur_p, ctx->k);
ctx->k_set = true;
}
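
The later commit 877a495 ("Fix to guarantee K-Shift on the first step only") addresses the issue raised in the review comment above. A minimal sketch of one way to provide that guarantee, assuming the shift should simply be consumed on the first `apply` call regardless of whether it can be performed (an illustration, not necessarily the PR's actual fix):

```cpp
static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_k_shift *) smpl->ctx;

    if (ctx->k_set) {
        return;
    }

    // mark the shift as consumed on the first step, even if it cannot be
    // performed, so a truncated candidate list cannot defer it to later tokens
    ctx->k_set = true;

    if (ctx->k <= 0 || ctx->k >= (int) cur_p->size) {
        return;
    }

    llama_sampler_top_shift_impl(cur_p, ctx->k);
}
```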

static struct llama_sampler * llama_sampler_k_shift_clone(const struct llama_sampler * smpl) {
auto * ctx = (const llama_sampler_k_shift *) smpl->ctx;

return llama_sampler_init_k_shift(ctx->k);
}

static void llama_sampler_k_shift_free(struct llama_sampler * smpl) {
delete (llama_sampler_k_shift *) smpl->ctx;
}

static void llama_sampler_k_shift_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
ctx->k_set = false;
}

static struct llama_sampler_i llama_sampler_k_shift_i = {
/* .name = */ llama_sampler_k_shift_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_k_shift_apply,
/* .reset = */ llama_sampler_k_shift_reset,
/* .clone = */ llama_sampler_k_shift_clone,
/* .free = */ llama_sampler_k_shift_free,
};

struct llama_sampler * llama_sampler_init_k_shift(int32_t k) {
return new llama_sampler {
/* .iface = */ &llama_sampler_k_shift_i,
/* .ctx = */ new llama_sampler_k_shift {
/* .k = */ k,
/* .k_set = */ false,
},
};
}

// mirostat

struct llama_sampler_mirostat {
29 changes: 24 additions & 5 deletions tests/test-sampling.cpp
@@ -83,6 +83,17 @@ static void test_temp_ext(const std::vector<float> & probs, const std::vector<fl
tester.check();
}

static void test_k_shift(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
sampler_tester tester(probs, probs_expected);

DUMP(&tester.cur_p);
tester.apply(llama_sampler_init_k_shift(k));
tester.apply(llama_sampler_init_dist (0));
DUMP(&tester.cur_p);

tester.check();
}

static void test_top_k(const std::vector<float> & probs, const std::vector<float> & probs_expected, int k) {
sampler_tester tester(probs, probs_expected);

@@ -288,11 +299,13 @@ static void test_perf() {
data.emplace_back(llama_token_data{i, logit, 0.0f});
}

BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
BENCH(llama_sampler_init_typical(0.5f, 1), data, 32);
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);

BENCH(llama_sampler_init_k_shift (10), data, 32);
BENCH(llama_sampler_init_top_k (40), data, 32);
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
}

int main(void) {
@@ -304,6 +317,12 @@ int main(void) {
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);

test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 3);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.66666f, 0.33333f}, 2);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.5f, 0.33333f, 0.16666f}, 1);
test_k_shift({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
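// note on the expectations above: k = 1 drops the top candidate (0.4) and the
// remaining {0.3, 0.2, 0.1} renormalize to {0.5, 0.33333, 0.16666}; k = 3 leaves
// only the least likely candidate, hence {1.0f}; k = 0 and k = 4 (>= the number
// of candidates) leave the distribution untouched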

test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);