
Commit ecadd37

add post_sampling_probs option

1 parent c0cca53 commit ecadd37

4 files changed: +151 −79 lines


examples/server/README.md

Lines changed: 44 additions & 40 deletions
````diff
@@ -449,52 +449,56 @@ These words will not be included in the completion, so make sure to add them to
 
 `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
 
+`post_sampling_probs`: Returns the probabilities of the top `n_probs` tokens after applying the sampling chain.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
 
 - `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements:
+  ```json
+  {
+    "content": "<the generated completion text>",
+    "tokens": [ generated token ids if requested ],
+    ...
+    "probs": [
+      {
+        "id": <token id>,
+        "logprob": float,
+        "token": "<most likely token>",
+        "bytes": [int, int, ...],
+        "top_logprobs": [
+          {
+            "id": <token id>,
+            "logprob": float,
+            "token": "<token text>",
+            "bytes": [int, int, ...],
+          },
+          {
+            "id": <token id>,
+            "logprob": float,
+            "token": "<token text>",
+            "bytes": [int, int, ...],
+          },
+          ...
+        ]
+      },
+      {
+        "id": <token id>,
+        "logprob": float,
+        "token": "<most likely token>",
+        "bytes": [int, int, ...],
+        "top_logprobs": [
+          ...
+        ]
+      },
+      ...
+    ]
+  },
+  ```
+  Please note that if `post_sampling_probs` is set to `true`:
+  - `logprob` will be replaced with `prob`, with a value between 0.0 and 1.0
+  - The returned number of probabilities may be less than `n_probs`
 
 - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
 - `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
````
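For a quick end-to-end check of the documented behavior, a request might look like the sketch below (Python with `requests`; the host/port and the prompt are assumptions, chosen to match the server's default binding and the test added in this commit):

```python
# Minimal sketch: call /completion with post_sampling_probs enabled.
# Assumes a llama-server instance listening on http://localhost:8080.
import requests

res = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 5,
    "n_probs": 10,
    "post_sampling_probs": True,
})
res.raise_for_status()

for tok in res.json()["completion_probabilities"]:
    # with post_sampling_probs, entries carry "prob" (0.0 to 1.0) instead of "logprob"
    print(tok["token"], tok["prob"])
    for cand in tok["top_logprobs"]:
        print("    ", cand["token"], cand["prob"])
```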

examples/server/server.cpp

Lines changed: 77 additions & 39 deletions
```diff
@@ -93,6 +93,7 @@ struct slot_params {
 
     std::vector<std::string> antiprompt;
     bool timings_per_token = false;
+    bool post_sampling_probs = false;
     bool ignore_eos = false;
 
     struct common_params_sampling sampling;
@@ -151,6 +152,7 @@ struct slot_params {
             {"speculative.n_min", speculative.n_min},
             {"speculative.p_min", speculative.p_min},
             {"timings_per_token", timings_per_token},
+            {"post_sampling_probs", post_sampling_probs},
         };
     }
 };
@@ -231,6 +233,7 @@ struct server_task {
         params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
         params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
         params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+        params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
 
         params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
         params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
@@ -449,7 +452,7 @@ struct completion_token_output {
     };
     std::vector<token_prob> probs;
 
-    json to_json() const {
+    json to_json(bool post_sampling_probs) const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
             std::string tok_str(p.tok_str);
@@ -458,23 +461,29 @@ struct completion_token_output {
                 {"id", p.tok},
                 {"token", tok_str},
                 {"bytes", str_to_bytes(p.tok_str)},
-                {"logprob", logarithm(p.prob)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
+                },
             });
         }
         return probs_for_token;
     }
 
-    static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
+    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
         json out = json::array();
         for (const auto & it : probs) {
             std::string tok_str(it.text_to_send);
             tok_str.resize(validate_utf8(tok_str));
             out.push_back(json {
                 {"id", it.tok},
                 {"token", tok_str},
-                {"logprob", logarithm(it.prob)},
                 {"bytes", str_to_bytes(it.text_to_send)},
-                {"top_logprobs", it.to_json()},
+                {"top_logprobs", it.to_json(post_sampling_probs)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? it.prob : logarithm(it.prob)
+                },
             });
         }
         return out;
```
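The effect of the key switch above is that each serialized token entry keeps the same shape and only trades `logprob` for `prob`. A hypothetical side-by-side, with invented values (as Python dicts mirroring the JSON emitted by `probs_vector_to_json`):

```python
# Illustrative only: the same token entry under each mode; values are made up.
entry_default = {
    "id": 318,
    "token": " the",
    "bytes": [32, 116, 104, 101],
    "logprob": -0.113,        # natural log-probability, always <= 0.0
    "top_logprobs": [...],    # each candidate entry also carries "logprob"
}

entry_post_sampling = {
    "id": 318,
    "token": " the",
    "bytes": [32, 116, 104, 101],
    "prob": 0.893,            # exp(-0.113); a probability in [0.0, 1.0]
    "top_logprobs": [...],    # each candidate entry also carries "prob"
}
```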
```diff
@@ -512,6 +521,7 @@ struct server_task_result_cmpl_final : server_task_result {
     std::string stopping_word;
     stop_type stop = STOP_TYPE_NONE;
 
+    bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
 
     slot_params generation_params;
@@ -557,7 +567,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"timings", timings.to_json()},
         };
         if (!stream && !probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
         return res;
     }
@@ -579,7 +589,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
         if (!stream && probs_output.size() > 0) {
             choice["logprobs"] = json{
-                {"content", completion_token_output::probs_vector_to_json(probs_output)},
+                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
             };
         }
 
@@ -652,6 +662,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;
 
+    bool post_sampling_probs;
     completion_token_output prob_output;
     result_timings timings;
 
@@ -690,7 +701,7 @@ struct server_task_result_cmpl_partial : server_task_result {
             res.push_back({"timings", timings.to_json()});
         }
         if (!prob_output.probs.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output});
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
         }
         return res;
     }
@@ -746,7 +757,7 @@ struct server_task_result_cmpl_partial : server_task_result {
 
         if (prob_output.probs.size() > 0) {
             choices[0]["logprobs"] = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output})},
+                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
             };
         }
 
@@ -1944,28 +1955,53 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
-    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
-        std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
-        int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
         size_t n_probs = slot.params.sampling.n_probs;
-
-        bool found_sampled_tok = false;
-        result.probs.reserve(n_probs);
-        for (int i = 0; i < n_vocab; i++) {
-            // set probability for sampled token
-            if (cur[i].id == result.tok) {
-                found_sampled_tok = true;
-                result.prob = cur[i].p;
+        int n_vocab = llama_n_vocab(llama_get_model(ctx));
+        if (post_sampling) {
+            // probabilities after the sampling chain: only the candidates that
+            // survived the samplers, so there may be fewer than n_probs of them
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+            const size_t max_probs = cur_p->size;
+
+            bool found_sampled_tok = false;
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < max_probs; i++) {
+                // set probability for sampled token
+                if (cur_p->data[i].id == result.tok) {
+                    found_sampled_tok = true;
+                    result.prob = cur_p->data[i].p;
+                }
+                // set probability for top n_probs tokens
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    common_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+                // break if we have all the necessary data
+                if (result.probs.size() == n_probs && found_sampled_tok) {
+                    break;
+                }
             }
-            // set probability for top n_probs tokens
-            result.probs.push_back({
-                cur[i].id,
-                common_detokenize(ctx, {cur[i].id}, special),
-                cur[i].p
-            });
-            // break if we have all the necessary data
-            if (result.probs.size() == n_probs && found_sampled_tok) {
-                break;
+        } else {
+            // raw (pre-sampling) probabilities over the full vocabulary
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+            bool found_sampled_tok = false;
+            result.probs.reserve(n_probs);
+            for (int i = 0; i < n_vocab; i++) {
+                // set probability for sampled token
+                if (cur[i].id == result.tok) {
+                    found_sampled_tok = true;
+                    result.prob = cur[i].p;
+                }
+                // set probability for top n_probs tokens
+                result.probs.push_back({
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+                // break if we have all the necessary data
+                if (result.probs.size() == n_probs && found_sampled_tok) {
+                    break;
+                }
             }
         }
     }
```
```diff
@@ -1997,8 +2033,9 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens = { tkn.tok };
 
-        res->n_decoded = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
+        res->n_decoded           = slot.n_decoded;
+        res->n_prompt_tokens     = slot.n_prompt_tokens;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
 
         res->verbose = slot.params.verbose;
         res->oaicompat = slot.params.oaicompat;
@@ -2030,13 +2067,14 @@ struct server_context {
         res->timings = slot.get_timings();
         res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
 
-        res->truncated = slot.truncated;
-        res->n_decoded = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
-        res->n_tokens_cached = slot.n_past;
-        res->has_new_line = slot.has_new_line;
-        res->stopping_word = slot.stopping_word;
-        res->stop = slot.stop;
+        res->truncated           = slot.truncated;
+        res->n_decoded           = slot.n_decoded;
+        res->n_prompt_tokens     = slot.n_prompt_tokens;
+        res->n_tokens_cached     = slot.n_past;
+        res->has_new_line        = slot.has_new_line;
+        res->stopping_word       = slot.stopping_word;
+        res->stop                = slot.stop;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
 
         res->verbose = slot.params.verbose;
         res->stream = slot.params.stream;
@@ -2859,7 +2897,7 @@ struct server_context {
             result.prob = 1.0f; // set later
 
             if (slot.params.sampling.n_probs > 0) {
-                populate_token_probs(slot, result, params_base.special, tok_idx);
+                populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
             }
 
             if (!process_token(result, slot)) {
```
examples/server/tests/unit/test_completion.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -309,3 +309,30 @@ def test_n_probs_stream():
             assert "token" in prob and type(prob["token"]) == str
             assert "logprob" in prob and prob["logprob"] <= 0.0
             assert "bytes" in prob and type(prob["bytes"]) == list
+
+
+def test_n_probs_post_sampling():
+    global server
+    server.multi_token_probs = True
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+        "post_sampling_probs": True,
+    })
+    assert res.status_code == 200
+    assert "completion_probabilities" in res.body
+    assert len(res.body["completion_probabilities"]) == 5
+    for tok in res.body["completion_probabilities"]:
+        assert "id" in tok and tok["id"] > 0
+        assert "token" in tok and type(tok["token"]) == str
+        assert "prob" in tok and 0.0 <= tok["prob"] <= 1.0
+        assert "bytes" in tok and type(tok["bytes"]) == list
+        assert len(tok["top_logprobs"]) == 10
+        for prob in tok["top_logprobs"]:
+            assert "id" in prob and prob["id"] > 0
+            assert "token" in prob and type(prob["token"]) == str
+            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
+            assert "bytes" in prob and type(prob["bytes"]) == list
```

examples/server/tests/unit/test_embedding.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -50,6 +50,8 @@ def test_embedding_multiple():
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
+        # do not crash on empty input
+        ("", False),
         # single prompt
         ("string", False),
         ([12, 34, 56], False),
@@ -103,6 +105,7 @@ def test_embedding_pooling_none_oai():
 
     # /v1/embeddings does not support pooling type 'none'
     assert res.status_code == 400
+    assert "error" in res.body
 
 
 def test_embedding_openai_library_single():
```
