Commit cc90cdb

return pre-sampling p
1 parent 01afafe commit cc90cdb

File tree

2 files changed: +43 -19 lines changed

  examples/server/server.cpp
  examples/server/utils.hpp

examples/server/server.cpp

Lines changed: 13 additions & 19 deletions
@@ -1886,25 +1886,17 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
-    void populate_token_probs(const server_slot & slot, completion_token_output & result) {
-        const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-        const size_t max_probs = cur_p->size;
-
-        // set prob for the sampled token
-        for (size_t i = 0; i < max_probs; ++i) {
-            if (result.tok == cur_p->data[i].id) {
-                result.prob = cur_p->data[i].p;
-                break;
-            }
-        }
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
+        std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+        int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-        // set probs for the top n tokens
-        for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) {
-            auto tok_id = cur_p->data[i].id;
+        // only take at most n_probs tokens
+        const int n_probs = slot.params.sampling.n_probs;
+        for (int i = 0; i < std::min(n_probs, n_vocab); i++) {
             result.probs.push_back({
-                tok_id,
-                tokens_to_output_formatted_string(ctx, tok_id),
-                cur_p->data[i].p,
+                cur[i].id,
+                common_detokenize(ctx, {cur[i].id}, special),
+                cur[i].p
             });
         }
     }
@@ -2758,7 +2750,9 @@ struct server_context {
                 continue; // continue loop of slots
             }
 
-            llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+            const int tok_idx = slot.i_batch - i;
+
+            llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
 
             slot.i_batch = -1;
 
@@ -2782,7 +2776,7 @@ struct server_context {
             result.prob = 1.0f; // set later
 
            if (slot.params.sampling.n_probs > 0) {
-                populate_token_probs(slot, result);
+                populate_token_probs(slot, result, params_base.special, tok_idx);
            }
 
            if (!process_token(result, slot)) {

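Note: the change above switches populate_token_probs() from reading probabilities out of the sampler's candidate list to reading them from the raw logits at the token's batch index, so the reported values reflect the model's distribution before any sampling transforms are applied. A minimal, self-contained sketch of why pre- and post-sampling probabilities differ when a truncating sampler such as top-k renormalizes (toy values and names are illustrative, not part of the patch):

// Toy comparison of pre- vs. post-sampling probabilities (illustration only,
// not llama.cpp code): a truncating sampler such as top-k renormalizes the
// surviving candidates, so their reported p exceeds the full-vocabulary softmax p.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Raw logits for a 4-token toy vocabulary.
    const std::vector<float> logits = {2.0f, 1.0f, 0.5f, -1.0f};

    // Pre-sampling p: numerically stable softmax over the full vocabulary.
    const float max_l = *std::max_element(logits.begin(), logits.end());
    std::vector<float> pre(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        pre[i] = std::exp(logits[i] - max_l);
        sum += pre[i];
    }
    for (float & p : pre) {
        p /= sum;
    }

    // Post-sampling p: top-k with k = 2 keeps the two best tokens and
    // renormalizes over just those two.
    const float kept  = pre[0] + pre[1];
    const float post0 = pre[0] / kept;
    const float post1 = pre[1] / kept;

    std::printf("token 0: pre = %.3f, post(top-2) = %.3f\n", pre[0], post0);
    std::printf("token 1: pre = %.3f, post(top-2) = %.3f\n", pre[1], post1);
    return 0;
}
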
examples/server/utils.hpp

Lines changed: 30 additions & 0 deletions
@@ -694,3 +694,33 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
 static std::string safe_json_to_str(json data) {
     return data.dump(-1, ' ', false, json::error_handler_t::replace);
 }
+
+static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
+    std::vector<llama_token_data> cur;
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    // apply softmax
+    float max_l = cur[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < cur.size(); ++i) {
+        float p = expf(cur[i].logit - max_l);
+        cur[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < cur.size(); ++i) {
+        cur[i].p /= cum_sum;
+    }
+
+    // sort tokens by probability
+    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
+        return a.p > b.p;
+    });
+
+    return cur;
+}
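
Note: the new get_token_probabilities() helper builds one llama_token_data entry per vocabulary token from the raw logits, converts the logits to probabilities with a softmax, and sorts descending by probability. A standalone sketch of the same logits -> softmax -> sort shape, using a plain struct in place of llama_token_data so it compiles without llama.cpp (names are illustrative; this sketch subtracts the true maximum logit before exponentiating):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative stand-in for llama_token_data: token id, raw logit, probability.
struct token_data {
    int   id;
    float logit;
    float p;
};

// Convert raw logits into a probability distribution (numerically stable
// softmax), then sort tokens by descending probability.
static std::vector<token_data> softmax_and_sort(const std::vector<float> & logits) {
    std::vector<token_data> cur(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        cur[i] = token_data{(int) i, logits[i], 0.0f};
    }

    // subtract the maximum logit so the exponentials cannot overflow
    float max_l = cur[0].logit;
    for (const auto & td : cur) {
        max_l = std::max(max_l, td.logit);
    }

    float cum_sum = 0.0f;
    for (auto & td : cur) {
        td.p = std::exp(td.logit - max_l);
        cum_sum += td.p;
    }
    for (auto & td : cur) {
        td.p /= cum_sum;
    }

    std::sort(cur.begin(), cur.end(), [](const token_data & a, const token_data & b) {
        return a.p > b.p;
    });
    return cur;
}

int main() {
    // 4 toy logits; the most probable token id is printed first
    for (const auto & td : softmax_and_sort({0.5f, 3.0f, -1.0f, 1.5f})) {
        std::printf("id = %d, p = %.4f\n", td.id, td.p);
    }
    return 0;
}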
