Commit 8284dc4

draft implementation of v1/completions echo logprobs support
1 parent 1425f58 commit 8284dc4

File tree

2 files changed: +210 -15 lines changed

tools/server/server.cpp

Lines changed: 210 additions & 10 deletions
@@ -112,6 +112,7 @@ struct slot_params {
     bool stream = true;
     bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
     bool return_tokens = false;
+    bool echo = false;
 
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -160,6 +161,7 @@ struct slot_params {
     }
 
     return json {
+        {"echo", echo},
         {"n_predict", n_predict}, // Server configured n_predict
         {"seed", sampling.seed},
        {"temperature", sampling.temp},
@@ -265,6 +267,7 @@ struct server_task {
        params.stream = json_value(data, "stream", false);
        params.cache_prompt = json_value(data, "cache_prompt", true);
        params.return_tokens = json_value(data, "return_tokens", false);
+       params.echo = json_value(data, "echo", false);
        params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
        params.n_indent = json_value(data, "n_indent", defaults.n_indent);
        params.n_keep = json_value(data, "n_keep", defaults.n_keep);
@@ -674,6 +677,98 @@ struct completion_token_output {
         return out;
     }
 
+    static json oaicompat_probs_vector_to_json(
+            const std::vector<completion_token_output> & probs_out,
+            bool post_sampling_probs,
+            bool echo,
+            const std::vector<completion_token_output> & prompt_probs = {}
+    ) {
+        json out = json::object();
+
+        std::vector<std::string> tokens;
+        std::vector<completion_token_output> all_probs;
+
+        if (echo && !prompt_probs.empty()) {
+            all_probs.insert(all_probs.end(), prompt_probs.begin(), prompt_probs.end());
+        }
+
+        all_probs.insert(all_probs.end(), probs_out.begin(), probs_out.end());
+
+        tokens.reserve(all_probs.size());
+        for (const auto & p : all_probs) {
+            std::string piece = p.text_to_send;
+            piece.resize(validate_utf8(piece));
+            tokens.push_back(piece);
+        }
+
+        int text_offset = 0;
+        std::vector<int> text_offsets;
+        text_offsets.reserve(tokens.size());
+
+        int current_off = text_offset;
+        for (const auto & tok : tokens) {
+            text_offsets.push_back(current_off);
+            current_off += static_cast<int>(tok.size());
+        }
+
+        std::vector<std::optional<float>> token_logprobs;
+        token_logprobs.reserve(all_probs.size());
+
+        std::vector<std::optional<std::unordered_map<std::string, float>>> top_logprobs;
+        top_logprobs.reserve(all_probs.size());
+
+        for (size_t i = 0; i < all_probs.size(); ++i) {
+            const auto & p = all_probs[i];
+
+            bool is_first_prompt_token = echo && (i == 0);
+
+            if (is_first_prompt_token) {
+                token_logprobs.push_back(std::nullopt);
+                top_logprobs.push_back(std::nullopt);
+            } else {
+                if (std::isinf(p.prob) && p.prob < 0) {
+                    token_logprobs.push_back(std::nullopt);
+                    top_logprobs.push_back(std::nullopt);
+                } else {
+                    float logprob_value = p.prob;
+                    if (!post_sampling_probs) {
+                        logprob_value = p.prob;
+                    } else {
+                        logprob_value = p.prob > 0.0f ? std::log(p.prob) : -std::numeric_limits<float>::infinity();
+                    }
+
+                    token_logprobs.push_back(std::optional<float>(logprob_value));
+
+                    std::unordered_map<std::string, float> top_map;
+                    for (const auto & cand : p.probs) {
+                        std::string cand_txt = cand.txt;
+                        cand_txt.resize(validate_utf8(cand_txt));
+
+                        float cand_logprob;
+                        if (!post_sampling_probs) {
+                            cand_logprob = cand.prob;
+                        } else {
+                            cand_logprob = cand.prob > 0.0f ? std::log(cand.prob) : -std::numeric_limits<float>::infinity();
+                        }
+
+                        top_map[cand_txt] = cand_logprob;
+                    }
+
+                    top_logprobs.push_back(std::move(top_map));
+                }
+            }
+        }
+
+        out = json{
+            {"text_offset", text_offsets},
+            {"token_logprobs", token_logprobs},
+            {"tokens", tokens},
+            {"top_logprobs", top_logprobs}
+        };
+
+        return out;
+    }
+
     static float logarithm(float x) {
         // nlohmann::json converts -inf to null, so we need to prevent that
         return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
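
For orientation, the object built by the new oaicompat_probs_vector_to_json helper follows the logprobs layout of the OpenAI /v1/completions response: parallel arrays keyed by text_offset, token_logprobs, tokens, and top_logprobs, with a null entry for the first prompt token when echo is active. Below is a minimal, hand-written sketch of that shape for a hypothetical two-token prompt ("Hello", " world") plus one generated token (" !"); the numeric values are invented for illustration and are not output from this commit.

// Illustrative only: the field layout the helper produces, rebuilt by hand
// with nlohmann::json. All logprob values below are made up.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::ordered_json logprobs = {
        {"text_offset",    {0, 5, 11}},                 // cumulative byte offset of each token in the echoed text
        {"token_logprobs", {nullptr, -0.12, -0.87}},    // null for the first prompt token (no preceding context)
        {"tokens",         {"Hello", " world", " !"}},
        {"top_logprobs",   {nullptr,
                            {{" world", -0.12}, {" there", -2.31}},
                            {{" !", -0.87}, {".", -1.42}}}}
    };
    std::cout << logprobs.dump(2) << std::endl;
    return 0;
}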
@@ -697,6 +792,7 @@ struct server_task_result_cmpl_final : server_task_result {
     bool stream;
     result_timings timings;
     std::string prompt;
+    bool echo = false;
 
     bool truncated;
     int32_t n_decoded;
@@ -708,6 +804,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
+    std::vector<completion_token_output> prompt_probs_output;
     std::vector<std::string> response_fields;
 
     slot_params generation_params;
@@ -769,19 +866,26 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_oaicompat() {
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
-        if (!stream && probs_output.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
-            };
+        if (!stream && (probs_output.size() > 0 || (echo && prompt_probs_output.size() > 0))) {
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                probs_output,
+                post_sampling_probs,
+                echo,
+                prompt_probs_output
+            );
         }
         json finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
             finish_reason = "stop";
         }
+        std::string response_text = content;
+        if (echo && !stream) {
+            response_text = prompt + content;
+        }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                    {"text", stream ? "" : response_text}, // in stream mode, content is already in last partial chunk
                     {"index", index},
                     {"logprobs", logprobs},
                    {"finish_reason", finish_reason},
@@ -940,6 +1044,10 @@ struct server_task_result_cmpl_partial : server_task_result {
     std::string oaicompat_cmpl_id;
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
 
+    bool echo = false;
+    std::string prompt_text;
+    bool is_first_chunk = false;
+
     virtual int get_index() override {
         return index;
     }
@@ -986,14 +1094,21 @@ struct server_task_result_cmpl_partial : server_task_result {
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
         if (prob_output.probs.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-            };
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                std::vector<completion_token_output>{prob_output},
+                post_sampling_probs,
+                echo
+            );
+        }
+
+        std::string response_text = content;
+        if (echo && is_first_chunk) {
+            response_text = prompt_text + content;
         }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text", content},
+                    {"text", response_text},
                     {"index", index},
                     {"logprobs", logprobs},
                    {"finish_reason", nullptr},
@@ -1321,6 +1436,8 @@ struct server_slot {
 
     // input prompt tokens
     server_tokens prompt_tokens;
+    std::string prompt_text;
+    std::vector<completion_token_output> prompt_token_probs;
 
     size_t last_nl_pos = 0;
 
@@ -1368,6 +1485,7 @@ struct server_slot {
         SLT_DBG(*this, "%s", "\n");
 
         n_prompt_tokens = 0;
+        prompt_text = "";
         last_nl_pos = 0;
         generated_text = "";
         has_new_line = false;
@@ -1381,6 +1499,7 @@ struct server_slot {
 
         generated_tokens.clear();
         generated_token_probs.clear();
+        prompt_token_probs.clear();
         chat_msg = {};
         json_schema = json();
         generated_tool_call_ids.clear();
@@ -2240,6 +2359,77 @@ struct server_context {
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
+        if (slot.params.echo) {
+            slot.prompt_text = slot.prompt_tokens.detokenize(ctx, true);
+
+            if (slot.params.sampling.n_probs > 0 && slot.prompt_tokens.size() > 0 && slot.prompt_token_probs.empty()) {
+                slot.prompt_token_probs.reserve(slot.prompt_tokens.size());
+
+                llama_batch batch = llama_batch_init(slot.prompt_tokens.size(), 0, 1);
+                for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                    common_batch_add(batch, slot.prompt_tokens[i], i, {0}, 1);
+                }
+
+                if (llama_decode(ctx, batch) == 0) {
+                    const int n_vocab = llama_vocab_n_tokens(vocab);
+                    for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                        completion_token_output prompt_token;
+                        prompt_token.tok = slot.prompt_tokens[i];
+                        prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+
+                        if (i > 0 && i < slot.prompt_tokens.size()) {
+                            const float * logits = llama_get_logits_ith(ctx, i - 1);
+                            if (logits != nullptr) {
+                                std::vector<std::pair<float, llama_token>> logits_id;
+                                logits_id.reserve(n_vocab);
+
+                                for (int j = 0; j < n_vocab; j++) {
+                                    logits_id.emplace_back(logits[j], j);
+                                }
+
+                                prompt_token.probs.clear();
+                                size_t top_k = std::min(logits_id.size(), static_cast<size_t>(slot.params.sampling.n_probs));
+                                for (size_t k = 0; k < top_k; ++k) {
+                                    completion_token_output::prob_info prob_info;
+                                    prob_info.tok = logits_id[k].second;
+                                    prob_info.prob = logits_id[k].first;
+                                    prob_info.txt = common_token_to_piece(ctx, logits_id[k].second, true);
+                                    prompt_token.probs.push_back(prob_info);
+                                }
+
+                                auto actual_token_it = std::find_if(logits_id.begin(), logits_id.end(),
+                                    [&](const std::pair<float, llama_token> & pair) {
+                                        return pair.second == slot.prompt_tokens[i];
+                                    });
+
+                                if (actual_token_it != logits_id.end()) {
+                                    prompt_token.prob = actual_token_it->first;
+                                } else {
+                                    prompt_token.prob = -std::numeric_limits<float>::infinity();
+                                }
+                            } else {
+                                prompt_token.prob = -std::numeric_limits<float>::infinity();
+                            }
+                        } else {
+                            prompt_token.prob = -std::numeric_limits<float>::infinity();
+                        }
+
+                        slot.prompt_token_probs.push_back(prompt_token);
+                    }
+                } else {
+                    for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                        completion_token_output prompt_token;
+                        prompt_token.tok = slot.prompt_tokens[i];
+                        prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+                        prompt_token.prob = -std::numeric_limits<float>::infinity();
+                        slot.prompt_token_probs.push_back(prompt_token);
+                    }
+                }
+
+                llama_batch_free(batch);
+            }
+        }
+
         if (!are_lora_equal(slot.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
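
Two details of the hunk above are worth noting when reading the captured values: prompt_token.prob and the candidate prob fields store raw logits from llama_get_logits_ith (the std::log conversion in the new helper only runs on the post-sampling path), and the candidate loop takes the first n_probs entries of the unsorted logits_id vector rather than the k highest-scoring ones. A standalone sketch of how a raw logit row could be turned into OpenAI-style log-probabilities and a true top-k follows; plain C++, with invented helper names, not part of this commit.

// Sketch only (not part of this commit): normalize one row of raw logits with a
// numerically stable log-softmax, and pick the k highest-logit candidates.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<float> log_softmax(const std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (float l : logits) {
        sum_exp += std::exp(l - max_logit);                  // subtract max for stability
    }
    const float log_sum = max_logit + static_cast<float>(std::log(sum_exp));
    std::vector<float> out;
    out.reserve(logits.size());
    for (float l : logits) {
        out.push_back(l - log_sum);                          // log p = logit - logsumexp(logits)
    }
    return out;
}

// (logit, token-id) pairs -> the k entries with the largest logits moved to the front
void select_top_k(std::vector<std::pair<float, int32_t>> & logits_id, size_t k) {
    k = std::min(k, logits_id.size());
    std::partial_sort(logits_id.begin(), logits_id.begin() + k, logits_id.end(),
                      [](const auto & a, const auto & b) { return a.first > b.first; });
}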
@@ -2529,6 +2719,10 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens = { tkn.tok };
 
+        res->echo = slot.params.echo;
+        res->prompt_text = slot.prompt_text;
+        res->is_first_chunk = (slot.n_decoded == 1);
+
         res->n_decoded = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
         res->post_sampling_probs = slot.params.post_sampling_probs;
@@ -2562,7 +2756,9 @@ struct server_context {
         res->content = slot.generated_text;
         res->tokens = std::move(slot.generated_tokens);
         res->timings = slot.get_timings();
-        res->prompt = slot.prompt_tokens.detokenize(ctx, true);
+
+        res->echo = slot.params.echo;
+        res->prompt = slot.params.echo ? slot.prompt_text : slot.prompt_tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.params.response_fields);
 
         res->truncated = slot.truncated;
@@ -2595,6 +2791,10 @@ struct server_context {
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
            }
+
+            if (slot.params.echo && !slot.prompt_token_probs.empty()) {
+                res->prompt_probs_output = slot.prompt_token_probs;
+            }
         }
 
         res->generation_params = slot.params; // copy the parameters

tools/server/utils.hpp

Lines changed: 0 additions & 5 deletions
@@ -553,11 +553,6 @@ static json oaicompat_completion_params_parse(const json & body) {
         throw std::runtime_error("Only one completion choice is allowed");
     }
 
-    // Handle "echo" field
-    if (json_value(body, "echo", false)) {
-        throw std::runtime_error("Only no echo is supported");
-    }
-
     // Params supported by OAI but unsupported by llama.cpp
     static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
     for (const auto & param : unsupported_params) {
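
With the early rejection removed, a /v1/completions request can now carry "echo": true. A minimal sketch of such a request body follows; the field names track the OpenAI completions API, the mapping of "logprobs" onto the server's n_probs sampling parameter is assumed rather than shown in this diff, and the values are illustrative.

// Hypothetical request body for POST /v1/completions with echo enabled.
// "logprobs" here is the OAI-style top-k candidate count; its server-side
// handling is an assumption, not shown in this diff.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json req = {
        {"prompt", "Hello world"},
        {"max_tokens", 8},
        {"echo", true},        // previously rejected with "Only no echo is supported"
        {"logprobs", 3}
    };
    std::cout << req.dump(2) << std::endl;   // POST this body to the server's /v1/completions endpoint
    return 0;
}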
