
Commit 74dc729

server : fix logprobs, make it openai-compatible
1 parent 43041d2 commit 74dc729

File tree: 4 files changed, +217 / -69 lines

examples/server/server.cpp

Lines changed: 89 additions & 62 deletions
@@ -342,6 +342,11 @@ struct server_task {
             }
         }
 
+        if (params.sampling.n_probs > 0 && params.cache_prompt) {
+            SRV_WRN("cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n", params.sampling.n_probs);
+            params.cache_prompt = false;
+        }
+
         std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
         params.oaicompat_model = json_value(data, "model", model_name);
 
@@ -416,6 +421,7 @@ inline std::string stop_type_to_str(stop_type type) {
 
 struct completion_token_output {
     llama_token tok;
+    float       prob;
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
@@ -427,25 +433,41 @@ struct completion_token_output {
     json to_json() const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
+            std::string tok_str(p.tok_str);
+            tok_str.resize(validate_utf8(tok_str));
             probs_for_token.push_back(json {
-                {"tok_str", p.tok_str},
-                {"prob",    p.prob},
+                {"id",      p.tok},
+                {"token",   tok_str},
+                {"bytes",   str_to_bytes(p.tok_str)},
+                {"logprob", p.prob},
             });
         }
         return probs_for_token;
     }
 
     static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
         json out = json::array();
-        for (const auto & prob : probs) {
-            const std::string tok_str = prob.text_to_send;
+        for (const auto & it : probs) {
+            std::string tok_str(it.text_to_send);
+            tok_str.resize(validate_utf8(tok_str));
             out.push_back(json {
-                {"content", tok_str},
-                {"probs",   prob.to_json()},
+                {"id",           it.tok},
+                {"token",        tok_str},
+                {"logprob",      it.prob},
+                {"bytes",        str_to_bytes(it.text_to_send)},
+                {"top_logprobs", it.to_json()},
             });
         }
         return out;
     }
+
+    static std::vector<unsigned char> str_to_bytes(const std::string & str) {
+        std::vector<unsigned char> bytes;
+        for (unsigned char c : str) {
+            bytes.push_back(c);
+        }
+        return bytes;
+    }
 };
 
 struct server_task_result_cmpl_final : server_task_result {
@@ -506,7 +528,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"tokens_cached", n_tokens_cached},
             {"timings",       timings.to_json()},
         };
-        if (!probs_output.empty()) {
+        if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
         }
         return res;
@@ -518,19 +540,25 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }
 
-        json choices = json::array({json{
+        json choice = json{
             {"finish_reason", finish_reason},
             {"index", 0},
             {"message", json{
                 {"content", content},
                 {"role", "assistant"}
             }
-        }}});
+        }};
+
+        if (!stream && probs_output.size() > 0) {
+            choice["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json(probs_output)},
+            };
+        }
 
         std::time_t t = std::time(0);
 
         json res = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"model", oaicompat_model},
             {"object", "chat.completion"},
@@ -560,12 +588,14 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }
 
-        json choices = json::array({json{{"finish_reason", finish_reason},
-                                         {"index", 0},
-                                         {"delta", json::object()}}});
+        json choice = json{
+            {"finish_reason", finish_reason},
+            {"index", 0},
+            {"delta", json::object()}
+        };
 
         json ret = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"id", oaicompat_cmpl_id},
             {"model", oaicompat_model},
@@ -592,7 +622,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;
 
-    std::vector<completion_token_output> probs_output;
+    completion_token_output prob_output;
     result_timings timings;
 
     // OAI-compat fields
@@ -628,8 +658,8 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n > 0) {
            res.push_back({"timings", timings.to_json()});
         }
-        if (!probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        if (!prob_output.probs.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output});
         }
         return res;
     }
@@ -681,6 +711,14 @@ struct server_task_result_cmpl_partial : server_task_result {
             }});
         }
 
+        GGML_ASSERT(choices.size() >= 1);
+
+        if (prob_output.probs.size() > 0) {
+            choices[0]["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json({prob_output})},
+            };
+        }
+
         json ret = json {
             {"choices", choices},
             {"created", t},
@@ -951,7 +989,6 @@ struct server_slot {
 
     // stats
     size_t n_sent_text = 0; // number of sent text character
-    size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -973,7 +1010,6 @@ struct server_slot {
         stopping_word = "";
         n_past = 0;
         n_sent_text = 0;
-        n_sent_token_probs = 0;
         task_type = SERVER_TASK_TYPE_COMPLETION;
 
         generated_token_probs.clear();
@@ -1713,34 +1749,15 @@ struct server_context {
 
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
+        const std::string token_str = result.text_to_send;
         slot.sampled = result.tok;
 
         // search stop word and delete it
         slot.generated_text += token_str;
         slot.has_next_token = true;
 
         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
 
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
@@ -1869,6 +1886,29 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
+    void populate_token_probs(const server_slot & slot, completion_token_output & result) {
+        const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+        const size_t max_probs = cur_p->size;
+
+        // set prob for the sampled token
+        for (size_t i = 0; i < max_probs; ++i) {
+            if (result.tok == cur_p->data[i].id) {
+                result.prob = cur_p->data[i].p;
+                break;
+            }
+        }
+
+        // set probs for the top n tokens
+        for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) {
+            auto tok_id = cur_p->data[i].id;
+            result.probs.push_back({
+                tok_id,
+                tokens_to_output_formatted_string(ctx, tok_id),
+                cur_p->data[i].p,
+            });
+        }
+    }
+
     void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         send_error(task.id, error, type);
     }
@@ -1906,17 +1946,7 @@ struct server_context {
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
-            const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
-
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs,                       slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                res->probs_output = std::vector<completion_token_output>(
-                    slot.generated_token_probs.begin() + probs_pos,
-                    slot.generated_token_probs.begin() + probs_stop_pos);
-            }
+            res->prob_output = tkn; // copy the token probs
         }
 
         // populate timings if this is final response or timings_per_token is enabled
@@ -2747,17 +2777,12 @@ struct server_context {
                 slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
 
                 completion_token_output result;
-                result.tok = id;
+                result.tok          = id;
+                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.prob         = 1.0f; // set later
 
-                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-
-                for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
-                    auto tok_id = cur_p->data[i].id;
-                    result.probs.push_back({
-                        tok_id,
-                        tokens_to_output_formatted_string(ctx, tok_id),
-                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                    });
+                if (slot.params.sampling.n_probs > 0) {
+                    populate_token_probs(slot, result);
                 }
 
                 if (!process_token(result, slot)) {
@@ -2841,7 +2866,9 @@ struct server_context {
                 for (size_t i = 0; i < ids.size(); ++i) {
                     completion_token_output result;
 
-                    result.tok = ids[i];
+                    result.tok          = ids[i];
+                    result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                    result.prob         = 1.0f; // set later
 
                     if (!process_token(result, slot)) {
                         // release slot because of stop condition

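Note (not part of the commit): based on the keys emitted by the new to_json() and probs_vector_to_json() helpers above, a non-streaming chat completion choice with logprobs enabled should take roughly the shape sketched below. The token ids, strings, and numbers are made-up placeholders; per the diff, the "logprob" field carries the sampler probability p reported for the token.

# Illustrative sketch only: the shape of the OpenAI-compatible "logprobs"
# field produced by the new to_json() / probs_vector_to_json() helpers.
# Ids, tokens, and values are invented for illustration.
choice = {
    "index": 0,
    "finish_reason": "length",
    "message": {"role": "assistant", "content": " The best"},
    "logprobs": {
        "content": [
            {
                "id": 450,                    # llama_token id of the sampled token
                "token": " The",              # UTF-8-validated piece (text_to_send)
                "logprob": 0.42,              # sampler probability p for this token
                "bytes": [32, 84, 104, 101],  # raw bytes of text_to_send (str_to_bytes)
                "top_logprobs": [             # up to n_probs candidates (populate_token_probs)
                    {"id": 450, "token": " The", "bytes": [32, 84, 104, 101], "logprob": 0.42},
                    {"id": 319, "token": " A",   "bytes": [32, 65],           "logprob": 0.17},
                ],
            },
            {
                "id": 1900,
                "token": " best",
                "logprob": 0.31,
                "bytes": [32, 98, 101, 115, 116],
                "top_logprobs": [
                    {"id": 1900, "token": " best", "bytes": [32, 98, 101, 115, 116], "logprob": 0.31},
                ],
            },
        ],
    },
}

# Concatenating the per-token pieces reproduces the message content,
# which is exactly what the new tests below assert.
assert "".join(t["token"] for t in choice["logprobs"]["content"]) == choice["message"]["content"]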
examples/server/tests/unit/test_chat_completion.py

Lines changed: 61 additions & 1 deletion
@@ -92,7 +92,6 @@ def test_chat_completion_with_openai_library():
         seed=42,
         temperature=0.8,
     )
-    print(res)
     assert res.choices[0].finish_reason == "length"
     assert res.choices[0].message.content is not None
     assert match_regex("(Suddenly)+", res.choices[0].message.content)
@@ -163,3 +162,64 @@ def test_chat_completion_with_timings_per_token():
     assert "predicted_per_second" in data["timings"]
     assert "predicted_n" in data["timings"]
     assert data["timings"]["predicted_n"] <= 10
+
+
+def test_logprobs():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=5,
+        logprobs=True,
+        top_logprobs=10,
+    )
+    output_text = res.choices[0].message.content
+    aggregated_text = ''
+    assert res.choices[0].logprobs is not None
+    assert res.choices[0].logprobs.content is not None
+    for token in res.choices[0].logprobs.content:
+        aggregated_text += token.token
+        assert 0.0 <= token.logprob <= 1.0
+        assert token.bytes is not None and len(token.bytes) > 0
+        assert len(token.top_logprobs) > 0
+    assert aggregated_text == output_text
+
+
+def test_logprobs_stream():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=5,
+        logprobs=True,
+        top_logprobs=10,
+        stream=True,
+    )
+    output_text = ''
+    aggregated_text = ''
+    for data in res:
+        choice = data.choices[0]
+        if choice.finish_reason is None:
+            if choice.delta.content:
+                output_text += choice.delta.content
+            assert choice.logprobs is not None
+            assert choice.logprobs.content is not None
+            for token in choice.logprobs.content:
+                aggregated_text += token.token
+                assert 0.0 <= token.logprob <= 1.0
+                assert token.bytes is not None and len(token.bytes) > 0
+                assert token.top_logprobs is not None
+                assert len(token.top_logprobs) > 0
+    assert aggregated_text == output_text

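For the non-OpenAI-compatible path, the same per-token data is returned under completion_probabilities (see the !stream && !probs_output.empty() check in server.cpp above). The sketch below is not part of the commit: the /completion endpoint, host/port, prompt, and n_predict/n_probs values are assumptions, while the per-token keys (id, token, logprob, bytes, top_logprobs) come from probs_vector_to_json().

# Rough sketch, assuming a locally running llama-server: reading the reshaped
# per-token fields from the native completion endpoint. URL, prompt, and
# parameter values are assumptions; field names come from probs_vector_to_json().
import requests

resp = requests.post(
    "http://localhost:8080/completion",  # assumed local server address
    json={
        "prompt": "What is the best book",
        "n_predict": 5,
        "n_probs": 10,  # per the new check above, this disables cache_prompt for the request
    },
)
resp.raise_for_status()
data = resp.json()

for tok in data.get("completion_probabilities", []):
    # one entry per generated token: sampled id/piece plus its top-n candidates
    print(tok["id"], repr(tok["token"]), tok["logprob"], len(tok["bytes"]), len(tok["top_logprobs"]))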