
Commit d2463dc

resolve review comments
1 parent fd4cf34 commit d2463dc

File tree

3 files changed: +36 -33 lines changed


examples/server/README.md

Lines changed: 2 additions & 2 deletions
@@ -496,8 +496,8 @@ These words will not be included in the completion, so make sure to add them to
 },
 ```
 Please note that if `post_sampling_probs` is set to `true`:
-  - `logprob` will be replace with `prob`, with the value between 0.0 and 1.0
-  - `top_logprobs` will be replace with `top_probs`. Each element inside contains:
+  - `logprob` will be replaced with `prob`, with the value between 0.0 and 1.0
+  - `top_logprobs` will be replaced with `top_probs`. Each element contains:
     - `id`: token ID
     - `token`: token in string
     - `bytes`: token in bytes
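
For reference, a `completion_probabilities` entry with `post_sampling_probs` enabled would take roughly the shape sketched below (the token IDs, strings, bytes and probabilities are made up; the field names follow the README text above):

```python
# Illustrative only: a hypothetical entry from "completion_probabilities"
# when the request sets "post_sampling_probs": true.
entry = {
    "id": 123,                           # token ID (made up)
    "token": "Hello",                    # token as a string
    "bytes": [72, 101, 108, 108, 111],   # token as raw bytes
    "prob": 0.98,                        # probability in (0.0, 1.0] instead of "logprob"
    "top_probs": [                       # replaces "top_logprobs"
        {"id": 123, "token": "Hello", "bytes": [72, 101, 108, 108, 111], "prob": 0.98},
        {"id": 456, "token": "Hi",    "bytes": [72, 105],                "prob": 0.02},
    ],
}

assert 0.0 < entry["prob"] <= 1.0
```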

examples/server/server.cpp

Lines changed: 31 additions & 30 deletions
@@ -443,20 +443,20 @@ struct completion_token_output {
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
-        std::string tok_str;
+        std::string txt;
         float prob;
     };
     std::vector<token_prob> probs;
 
     json to_json(bool post_sampling_probs) const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
-            std::string tok_str(p.tok_str);
-            tok_str.resize(validate_utf8(tok_str));
+            std::string txt(p.txt);
+            txt.resize(validate_utf8(txt));
             probs_for_token.push_back(json {
                 {"id", p.tok},
-                {"token", tok_str},
-                {"bytes", str_to_bytes(p.tok_str)},
+                {"token", txt},
+                {"bytes", str_to_bytes(p.txt)},
                 {
                     post_sampling_probs ? "prob" : "logprob",
                     post_sampling_probs ? p.prob : logarithm(p.prob)
@@ -468,20 +468,20 @@ struct completion_token_output {
 
     static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
         json out = json::array();
-        for (const auto & it : probs) {
-            std::string tok_str(it.text_to_send);
-            tok_str.resize(validate_utf8(tok_str));
+        for (const auto & p : probs) {
+            std::string txt(p.text_to_send);
+            txt.resize(validate_utf8(txt));
             out.push_back(json {
-                {"id", it.tok},
-                {"token", tok_str},
-                {"bytes", str_to_bytes(it.text_to_send)},
+                {"id", p.tok},
+                {"token", txt},
+                {"bytes", str_to_bytes(p.text_to_send)},
                 {
                     post_sampling_probs ? "top_probs" : "top_logprobs",
-                    it.to_json(post_sampling_probs)
+                    p.to_json(post_sampling_probs)
                 },
                 {
                     post_sampling_probs ? "prob" : "logprob",
-                    post_sampling_probs ? it.prob : logarithm(it.prob)
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
                 },
             });
         }
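
Both serializers also truncate the token text to its longest valid UTF-8 prefix (`validate_utf8` plus `resize`) before placing it in the `token` field, while `bytes` still carries the raw data. A minimal Python sketch of that idea, using a hypothetical `longest_valid_utf8_prefix` helper in place of the server's `validate_utf8`:

```python
def longest_valid_utf8_prefix(data: bytes) -> int:
    """Length of the longest prefix of `data` that decodes as UTF-8.

    Hypothetical stand-in for the server's validate_utf8(): only the trailing
    bytes of a token can form an incomplete multi-byte sequence, so trying the
    last few cut points is enough for this sketch.
    """
    for cut in range(len(data), max(len(data) - 4, -1), -1):
        try:
            data[:cut].decode("utf-8")
            return cut
        except UnicodeDecodeError:
            continue
    return 0


# A token whose text got split in the middle of a multi-byte character:
raw = "héllo".encode("utf-8")[:2]   # b'h\xc3' -- dangling lead byte
token_field = raw[:longest_valid_utf8_prefix(raw)].decode("utf-8")
bytes_field = list(raw)             # the raw bytes are still reported in full
print(token_field, bytes_field)     # h [104, 195]
```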
@@ -1958,44 +1958,45 @@ struct server_context {
         size_t n_probs = slot.params.sampling.n_probs;
         int n_vocab = llama_n_vocab(llama_get_model(ctx));
         if (post_sampling) {
-            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+            // TODO: optimize this with min-p optimization
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+            const size_t max_probs = cur_p->size;
 
             bool found_sampled_tok = false;
-            result.probs.reserve(n_probs);
-            for (int i = 0; i < n_vocab; i++) {
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < max_probs; i++) {
                 // set probability for sampled token
-                if (cur[i].id == result.tok) {
+                if (cur_p->data[i].id == result.tok) {
                     found_sampled_tok = true;
-                    result.prob = cur[i].p;
+                    result.prob = cur_p->data[i].p;
                 }
                 // set probability for top n_probs tokens
                 result.probs.push_back({
-                    cur[i].id,
-                    common_detokenize(ctx, {cur[i].id}, special),
-                    cur[i].p
+                    cur_p->data[i].id,
+                    common_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
                 });
                 // break if we have all the necessary data
                 if (result.probs.size() == n_probs && found_sampled_tok) {
                     break;
                 }
             }
         } else {
-            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-            const size_t max_probs = cur_p->size;
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
 
             bool found_sampled_tok = false;
-            result.probs.reserve(max_probs);
-            for (size_t i = 0; i < max_probs; i++) {
+            result.probs.reserve(n_probs);
+            for (int i = 0; i < n_vocab; i++) {
                 // set probability for sampled token
-                if (cur_p->data[i].id == result.tok) {
+                if (cur[i].id == result.tok) {
                     found_sampled_tok = true;
-                    result.prob = cur_p->data[i].p;
+                    result.prob = cur[i].p;
                 }
                 // set probability for top n_probs tokens
                 result.probs.push_back({
-                    cur_p->data[i].id,
-                    common_detokenize(ctx, {cur_p->data[i].id}, special),
-                    cur_p->data[i].p
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
                 });
                 // break if we have all the necessary data
                 if (result.probs.size() == n_probs && found_sampled_tok) {
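
So after this commit the two branches are swapped: with `post_sampling` the probabilities come from the sampler's candidate list (`common_sampler_get_candidates`), while the logprob path computes them over the full vocabulary via `get_token_probabilities`. Either way the collection loop is the same; a rough Python sketch of it, with made-up candidate data:

```python
# Hypothetical stand-in data: (token_id, probability) pairs, sorted by
# descending probability, as either branch would provide them.
candidates = [(11, 0.62), (42, 0.25), (7, 0.08), (99, 0.05)]

def collect_probs(candidates, sampled_tok, n_probs):
    """Mirror of the collection loop: record the sampled token's probability
    and the top n_probs candidates, stopping once both are available."""
    probs = []
    sampled_prob = None
    for tok_id, p in candidates:
        # set probability for sampled token
        if tok_id == sampled_tok:
            sampled_prob = p
        # set probability for top n_probs tokens
        probs.append((tok_id, p))
        # break if we have all the necessary data
        if len(probs) == n_probs and sampled_prob is not None:
            break
    return sampled_prob, probs

sampled_prob, top = collect_probs(candidates, sampled_tok=42, n_probs=2)
print(sampled_prob, top)   # 0.25 [(11, 0.62), (42, 0.25)]
```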

examples/server/tests/unit/test_completion.py

Lines changed: 3 additions & 1 deletion
@@ -325,11 +325,13 @@ def test_n_probs_post_sampling():
     for tok in res.body["completion_probabilities"]:
         assert "id" in tok and tok["id"] > 0
         assert "token" in tok and type(tok["token"]) == str
-        assert "prob" in tok and 0.0 <= tok["prob"] <= 1.0
+        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
         assert "bytes" in tok and type(tok["bytes"]) == list
         assert len(tok["top_probs"]) == 10
         for prob in tok["top_probs"]:
             assert "id" in prob and prob["id"] > 0
             assert "token" in prob and type(prob["token"]) == str
             assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
             assert "bytes" in prob and type(prob["bytes"]) == list
+        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
+        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
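
For context, these assertions run against the output of a request roughly like the sketch below (the `requests` usage, local URL and prompt are assumptions; `n_probs` and `post_sampling_probs` are the request fields the test exercises):

```python
import requests

# Assumes a llama-server instance listening locally; URL and prompt are made up.
res = requests.post(
    "http://127.0.0.1:8080/completion",
    json={
        "prompt": "I believe the meaning of life is",
        "n_predict": 5,
        "n_probs": 10,                 # report the 10 most likely tokens per step
        "post_sampling_probs": True,   # probabilities after the sampling chain
    },
)
res.raise_for_status()

for tok in res.json()["completion_probabilities"]:
    # With post_sampling_probs, the sampled token's probability is in (0.0, 1.0]
    # and the runner-up candidates are listed under "top_probs".
    print(tok["token"], tok["prob"], len(tok["top_probs"]))
```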
