
Commit e92d53b

sampling : optimize samplers by reusing bucket sort (ggml-org#15665)
* sampling : optimize sorting using bucket sort in more places
* sampling : do not sort in dist sampler
* sampling : avoid heap allocations for sort buffers
* common : add option to sort sampling candidates by probability
* sampling : revert the change for preserving sort buffers
* sampling : use std::copy instead of memcpy
* sampling : clarify purpose of partial sort helpers
* cont : remove wrong comment [no ci]
* common : update comment

Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 0d161f0 commit e92d53b

9 files changed: +228 −172 lines changed
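The bucket sort referenced in the commit title is not shown on this page. As a rough illustration of the technique, the sketch below selects the top-k candidates by binning them by probability and sorting only the highest buckets; all names and parameters here are invented for the example, and this is not the repository's actual helper.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

struct candidate {
    int   id;
    float p; // probability in [0, 1]
};

// Illustrative bucket-based top-k: bin candidates by probability, then sort
// only as many of the highest buckets as needed to produce the top k entries.
static std::vector<candidate> top_k_bucket(const std::vector<candidate> & cands, size_t k, int n_buckets = 128) {
    std::vector<std::vector<candidate>> buckets(n_buckets);
    for (const auto & c : cands) {
        const int b = std::min(n_buckets - 1, (int) (c.p * n_buckets));
        buckets[b].push_back(c);
    }

    std::vector<candidate> res;
    for (int b = n_buckets - 1; b >= 0 && res.size() < k; --b) {
        // sort within the bucket so the output is in descending order of p
        std::sort(buckets[b].begin(), buckets[b].end(),
                  [](const candidate & a, const candidate & c) { return a.p > c.p; });
        for (const auto & c : buckets[b]) {
            if (res.size() == k) {
                break;
            }
            res.push_back(c);
        }
    }
    return res;
}

int main() {
    const std::vector<candidate> cands = {{0, 0.05f}, {1, 0.40f}, {2, 0.30f}, {3, 0.25f}};
    for (const auto & c : top_k_bucket(cands, 2)) {
        std::printf("token %d p=%.2f\n", c.id, c.p);
    }
}
```

Only buckets that can still contribute to the top k are sorted, so for small k the bulk of the candidate array is never touched by a comparison sort.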

common/sampling.cpp

Lines changed: 23 additions & 2 deletions
```diff
@@ -426,8 +426,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
 
 // helpers
 
-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
-    return &gsmpl->cur_p;
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
+    auto * res = &gsmpl->cur_p;
+
+    if (do_sort && !res->sorted) {
+        // remember the selected token before sorting
+        const llama_token id = res->data[res->selected].id;
+
+        std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
+            return a.p > b.p;
+        });
+
+        // restore the selected token after sorting
+        for (size_t i = 0; i < res->size; ++i) {
+            if (res->data[i].id == id) {
+                res->selected = i;
+                break;
+            }
+        }
+
+        res->sorted = true;
+    }
+
+    return res;
 }
 
 llama_token common_sampler_last(const struct common_sampler * gsmpl) {
```
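As a usage sketch (hypothetical call site, not code from this commit): with `do_sort == true` the returned array is guaranteed to be in descending order of probability, while the remapped `selected` index keeps pointing at the sampled token.

```cpp
#include "sampling.h" // common/sampling.h from llama.cpp (assumed include path)

// Sketch only: assumes gsmpl is an initialized common_sampler on which
// common_sampler_sample() has already been called, so cur_p->selected is valid.
void inspect_top_candidates(struct common_sampler * gsmpl) {
    llama_token_data_array * cur_p = common_sampler_get_candidates(gsmpl, /*do_sort=*/true);

    // with do_sort == true the array is in descending order of probability,
    // so data[0] is the most probable candidate
    const llama_token top_id = cur_p->data[0].id;

    // the selected index was remapped during the sort, so it still refers
    // to the token that was actually sampled
    const llama_token sampled_id = cur_p->data[cur_p->selected].id;

    (void) top_id;
    (void) sampled_id;
}
```

Passing `do_sort == false` skips the sort entirely when order does not matter; the `sorted` flag on the result tells the caller which case it got.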

common/sampling.h

Lines changed: 3 additions & 1 deletion
```diff
@@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 // helpers
 
 // access the internal list of current candidate tokens
-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
+// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
+// the .sorted flag of the result indicates whether the returned candidates are sorted
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
 
 // get the last accepted token
 llama_token common_sampler_last(const struct common_sampler * gsmpl);
```

common/speculative.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -317,7 +317,7 @@ llama_tokens common_speculative_gen_draft(
 
     common_sampler_sample(smpl, ctx_dft, 0, true);
 
-    const auto * cur_p = common_sampler_get_candidates(smpl);
+    const auto * cur_p = common_sampler_get_candidates(smpl, true);
 
     for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
         LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
```

examples/speculative/speculative.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -244,7 +244,7 @@ int main(int argc, char ** argv) {
                 // stochastic verification
                 common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
 
-                auto & dist_tgt = *common_sampler_get_candidates(smpl);
+                auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
 
                 float p_tgt = 0.0f;
                 float p_dft = 0.0f;
@@ -493,7 +493,7 @@ int main(int argc, char ** argv) {
 
             common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
 
-            const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
+            const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
 
             for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
                 LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
```

include/llama.h

Lines changed: 1 addition & 6 deletions
```diff
@@ -206,7 +206,7 @@ extern "C" {
         llama_token_data * data;
         size_t size;
         int64_t selected; // this is the index in the data array (i.e. not the token id)
-        bool sorted;
+        bool sorted; // note: do not assume the data is sorted - always check this flag
     } llama_token_data_array;
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
@@ -1156,11 +1156,6 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
     LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
 
-    /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
-        "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
-
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     /// Setting k <= 0 makes this a noop
     LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
```
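Per the new note on the `sorted` flag, consumers of a `llama_token_data_array` should branch on the flag rather than assume any order. A minimal sketch, assuming only the public `llama.h` types (the helper itself is hypothetical, not part of the API):

```cpp
#include "llama.h"

// Hypothetical helper: index of the most probable candidate, honoring the
// rule that the data may or may not be sorted.
static size_t arg_max_p(const llama_token_data_array * arr) {
    if (arr->sorted) {
        // sorted flag set: candidates are in descending order, best is first
        return 0;
    }
    size_t best = 0;
    for (size_t i = 1; i < arr->size; ++i) {
        if (arr->data[i].p > arr->data[best].p) {
            best = i;
        }
    }
    return best;
}
```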
