sampling: add K-Shift sampler #10048
base: master
Changes from 28 commits
```diff
@@ -188,6 +188,17 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
     cur_p->size = k;
 }
 
+static void llama_sampler_top_shift_impl(llama_token_data_array * cur_p, int k) {
+    // sort before shifting
+    std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    });
+
+    // shift to a token #[k]
+    cur_p->data += k;
+    cur_p->size -= k;
+}
+
 static uint32_t get_rng_seed(uint32_t seed) {
     if (seed == LLAMA_DEFAULT_SEED) {
         // use system clock if std::random_device is not a true RNG
```
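For intuition, here is a minimal standalone sketch of what the new `llama_sampler_top_shift_impl` helper does, using simplified stand-in structs rather than the real `llama.h` types (the stand-ins and names below are illustrative only): the candidates are sorted by logit in descending order, then the array head is advanced by `k`, so the first `k` candidates are discarded and the token previously ranked at index `k` becomes the new top.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// simplified stand-ins for llama_token_data / llama_token_data_array
// (hypothetical types for illustration, not the real llama.h definitions)
struct token_data       { int id; float logit; };
struct token_data_array { token_data * data; size_t size; };

// mirrors the logic of llama_sampler_top_shift_impl from the diff above
static void top_shift(token_data_array * cur_p, int k) {
    // sort candidates by logit, descending
    std::sort(cur_p->data, cur_p->data + cur_p->size,
              [](const token_data & a, const token_data & b) { return a.logit > b.logit; });

    // drop the first k entries: the candidate previously at index k becomes the new head
    cur_p->data += k;
    cur_p->size -= k;
}

int main() {
    std::vector<token_data> tokens = {
        {0, 1.0f}, {1, 4.0f}, {2, 3.0f}, {3, 2.0f}, {4, 0.5f},
    };
    token_data_array arr = { tokens.data(), tokens.size() };

    top_shift(&arr, 2); // discard the two highest-logit candidates

    for (size_t i = 0; i < arr.size; ++i) {
        std::printf("id=%d logit=%.1f\n", arr.data[i].id, arr.data[i].logit);
    }
    // prints: id=3 logit=2.0, then id=0 logit=1.0, then id=4 logit=0.5
    return 0;
}
```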
```diff
@@ -1082,6 +1093,64 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
     };
 }
 
+// k-shift
+
+struct llama_sampler_k_shift {
+    const int32_t k;
+    bool k_set;
+};
+
+static const char * llama_sampler_k_shift_name(const struct llama_sampler * /*smpl*/) {
+    return "k-shift";
+}
+
+static void llama_sampler_k_shift_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
+
+    if (ctx->k_set == true
+        || ctx->k <= 0
+        || ctx->k >= (int) cur_p->size) {
+        return;
+    }
+
+    llama_sampler_top_shift_impl(cur_p, ctx->k);
+    ctx->k_set = true;
+}
+
+static struct llama_sampler * llama_sampler_k_shift_clone(const struct llama_sampler * smpl) {
+    auto * ctx = (const llama_sampler_k_shift *) smpl->ctx;
+
+    return llama_sampler_init_k_shift(ctx->k);
+}
+
+static void llama_sampler_k_shift_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_k_shift *) smpl->ctx;
+}
+
+static void llama_sampler_k_shift_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_k_shift *) smpl->ctx;
+    ctx->k_set = false;
+}
+
+static struct llama_sampler_i llama_sampler_k_shift_i = {
+    /* .name   = */ llama_sampler_k_shift_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_k_shift_apply,
+    /* .reset  = */ llama_sampler_k_shift_reset,
+    /* .clone  = */ llama_sampler_k_shift_clone,
+    /* .free   = */ llama_sampler_k_shift_free,
+};
+
+struct llama_sampler * llama_sampler_init_k_shift(int32_t k) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_k_shift_i,
+        /* .ctx   = */ new llama_sampler_k_shift {
+            /* .k     = */ k,
+            /* .k_set = */ false,
+        },
+    };
+}
+
 // mirostat
 
 struct llama_sampler_mirostat {
```
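Note that the apply step is one-shot: the `k_set` flag makes every call after the first a no-op until `reset()` re-arms it, so the shift only affects the first token of a generation. Below is a hedged sketch of that lifecycle, assuming the `llama.h` sampler API (`llama_sampler_apply`, `llama_sampler_reset`, `llama_sampler_free`) together with the `llama_sampler_init_k_shift()` constructor added in this PR, and using a dummy candidate array in place of real model logits.

```cpp
#include <vector>
#include "llama.h"

int main() {
    // dummy candidate list standing in for real logits (higher logit = higher rank)
    std::vector<llama_token_data> cand;
    for (llama_token tok = 0; tok < 1000; ++tok) {
        cand.push_back({ tok, /*logit =*/ float(1000 - tok), /*p =*/ 0.0f });
    }
    llama_token_data_array cur_p = { cand.data(), cand.size(), /*selected =*/ -1, /*sorted =*/ false };

    llama_sampler * ks = llama_sampler_init_k_shift(100);

    llama_sampler_apply(ks, &cur_p); // first call: sorts and drops the top 100 candidates
    llama_sampler_apply(ks, &cur_p); // no-op: k_set is already true

    llama_sampler_reset(ks);         // re-arms the shift for the next generation
    llama_sampler_free(ks);
    return 0;
}
```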
This does not match the paper, in which (AFAICT) exactly the k-th token is selected, rather than sampling from all tokens except the top k-1 ones.
The choice of token will be handled by the final step in the sampling queue, so greedy sampling would be needed to match the effect described in the paper. Considering that the greedy sampler was specifically removed recently, I don't think introducing another "final step" sampler would be OK.
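For context, a minimal sketch of how the pieces could be combined, assuming the `llama.h` sampler-chain API (`llama_sampler_chain_init` / `llama_sampler_chain_add`) and that `llama_sampler_init_greedy()` is still exposed at the library level: a greedy final step would reproduce the paper's "take exactly the k-th ranked token" behaviour, while a dist sampler would instead sample from the remaining candidates, as discussed above.

```cpp
#include "llama.h"

// hedged sketch, not part of the PR: prune the top k candidates once per
// generation, then let the final step pick the new head greedily
static llama_sampler * make_k_shift_chain(int32_t k) {
    llama_sampler_chain_params params = llama_sampler_chain_default_params();
    llama_sampler * chain = llama_sampler_chain_init(params);

    llama_sampler_chain_add(chain, llama_sampler_init_k_shift(k)); // prune: drop the top k once
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());   // choose: pick the new head

    // swapping the greedy step for llama_sampler_init_dist(LLAMA_DEFAULT_SEED)
    // would sample from the remaining candidates instead
    return chain;
}
```

The chain would then be driven with `llama_sampler_sample()` during generation, and `llama_sampler_reset()` between sequences would re-arm the one-shot shift.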