25 changes: 23 additions & 2 deletions common/sampling.cpp
@@ -426,8 +426,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {

// helpers

-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
-    return &gsmpl->cur_p;
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
+    auto * res = &gsmpl->cur_p;
+
+    if (do_sort && !res->sorted) {
+        // remember the selected token before sorting
+        const llama_token id = res->data[res->selected].id;
+
+        std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
+            return a.p > b.p;
+        });
+
+        // restore the selected token after sorting
+        for (size_t i = 0; i < res->size; ++i) {
+            if (res->data[i].id == id) {
+                res->selected = i;
+                break;
+            }
+        }
+
+        res->sorted = true;
+    }
+
+    return res;
}

llama_token common_sampler_last(const struct common_sampler * gsmpl) {
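For context, a minimal sketch (not part of this diff) of how a call site consumes the new parameter; the `smpl`/`ctx` variables and the batch index `0` are assumptions:

```cpp
// sample first, then fetch the candidates; do_sort = true guarantees
// descending order of probability (the sort only runs when needed)
common_sampler_sample(smpl, ctx, 0, true);
llama_token_data_array * cur_p = common_sampler_get_candidates(smpl, /*do_sort=*/true);

// cur_p->data[0] is now the most probable candidate
for (size_t i = 0; i < std::min((size_t) 3, cur_p->size); ++i) {
    LOG_DBG("candidate %zu: token %d (p = %.3f)\n", i, (int) cur_p->data[i].id, cur_p->data[i].p);
}
```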
4 changes: 3 additions & 1 deletion common/sampling.h
@@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers

// access the internal list of current candidate tokens
-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
+// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
+// the .sorted flag of the result indicates whether the returned candidates are sorted
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);

// get the last accepted token
llama_token common_sampler_last(const struct common_sampler * gsmpl);
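Either way, the `.sorted` flag remains the source of truth; a hedged sketch of a caller that tolerates unsorted candidates (variable names are illustrative):

```cpp
llama_token_data_array * cur_p = common_sampler_get_candidates(smpl, /*do_sort=*/false);

// without do_sort there is no ordering guarantee - check the flag
// before assuming data[0] is the argmax
size_t best = 0;
if (!cur_p->sorted) {
    for (size_t i = 1; i < cur_p->size; ++i) {
        if (cur_p->data[i].p > cur_p->data[best].p) {
            best = i;
        }
    }
}
const llama_token id = cur_p->data[best].id;
```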
2 changes: 1 addition & 1 deletion common/speculative.cpp
@@ -317,7 +317,7 @@ llama_tokens common_speculative_gen_draft(

common_sampler_sample(smpl, ctx_dft, 0, true);

-const auto * cur_p = common_sampler_get_candidates(smpl);
+const auto * cur_p = common_sampler_get_candidates(smpl, true);

for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
4 changes: 2 additions & 2 deletions examples/speculative/speculative.cpp
@@ -244,7 +244,7 @@ int main(int argc, char ** argv) {
// stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);

-auto & dist_tgt = *common_sampler_get_candidates(smpl);
+auto & dist_tgt = *common_sampler_get_candidates(smpl, true);

float p_tgt = 0.0f;
float p_dft = 0.0f;
@@ -493,7 +493,7 @@

common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);

-const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
+const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);

for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
7 changes: 1 addition & 6 deletions include/llama.h
@@ -206,7 +206,7 @@ extern "C" {
llama_token_data * data;
size_t size;
int64_t selected; // this is the index in the data array (i.e. not the token id)
-bool sorted;
+bool sorted; // note: do not assume the data is sorted - always check this flag
} llama_token_data_array;

typedef bool (*llama_progress_callback)(float progress, void * user_data);
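To illustrate the struct's fields, a sketch of how a candidates array is typically populated from raw logits before any sampler runs (the `n_vocab` and `logits` variables are assumptions; the field order follows the definition above):

```cpp
std::vector<llama_token_data> cur;
cur.reserve(n_vocab);
for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
    cur.push_back(llama_token_data{ id, logits[id], 0.0f });
}

llama_token_data_array cur_p = {
    /*.data     =*/ cur.data(),
    /*.size     =*/ cur.size(),
    /*.selected =*/ -1,    // no token chosen yet
    /*.sorted   =*/ false, // vocab order, not probability order
};
```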
@@ -1156,11 +1156,6 @@
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);

-/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
-    "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");

/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
/// Setting k <= 0 makes this a noop
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);