Skip to content

Commit ed7f2d5

Browse files
committed
set p for sampled token
1 parent 22b72c8 commit ed7f2d5

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

examples/server/server.cpp

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1889,15 +1889,26 @@ struct server_context {
18891889
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
18901890
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
18911891
int n_vocab = llama_n_vocab(llama_get_model(ctx));
1892+
size_t n_probs = slot.params.sampling.n_probs;
18921893

1893-
// only take at most n_probs tokens
1894-
const int n_probs = slot.params.sampling.n_probs;
1895-
for (int i = 0; i < std::min(n_probs, n_vocab); i++) {
1894+
bool found_sampled_tok = false;
1895+
result.probs.reserve(n_probs);
1896+
for (int i = 0; i < n_vocab; i++) {
1897+
// set probability for sampled token
1898+
if (cur[i].id == result.tok) {
1899+
found_sampled_tok = true;
1900+
result.prob = cur[i].p;
1901+
}
1902+
// set probability for top n_probs tokens
18961903
result.probs.push_back({
18971904
cur[i].id,
18981905
common_detokenize(ctx, {cur[i].id}, special),
18991906
cur[i].p
19001907
});
1908+
// break if we have all the necessary data
1909+
if (result.probs.size() == n_probs && found_sampled_tok) {
1910+
break;
1911+
}
19011912
}
19021913
}
19031914

0 commit comments

Comments
 (0)