
Commit 6817897

implementation of v1/completions echo logprobs support
1 parent 1425f58 commit 6817897

File tree

2 files changed (+239, −15):

  tools/server/server.cpp
  tools/server/utils.hpp


tools/server/server.cpp

Lines changed: 239 additions & 10 deletions
@@ -112,6 +112,7 @@ struct slot_params {
     bool stream = true;
     bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
     bool return_tokens = false;
+    bool echo = false;

     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -160,6 +161,7 @@ struct slot_params {
         }

         return json {
+            {"echo", echo},
             {"n_predict", n_predict}, // Server configured n_predict
             {"seed", sampling.seed},
             {"temperature", sampling.temp},
@@ -265,6 +267,7 @@ struct server_task {
         params.stream = json_value(data, "stream", false);
         params.cache_prompt = json_value(data, "cache_prompt", true);
         params.return_tokens = json_value(data, "return_tokens", false);
+        params.echo = json_value(data, "echo", false);
         params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent = json_value(data, "n_indent", defaults.n_indent);
         params.n_keep = json_value(data, "n_keep", defaults.n_keep);
@@ -674,6 +677,91 @@ struct completion_token_output {
         return out;
     }

+    static json oaicompat_probs_vector_to_json(
+            const std::vector<completion_token_output> & probs_out,
+            bool post_sampling_probs,
+            bool echo,
+            const std::vector<completion_token_output> & prompt_probs = {}
+    ) {
+        json out = json::object();
+
+        std::vector<std::string> tokens;
+        std::vector<completion_token_output> all_probs;
+
+        if (echo && !prompt_probs.empty()) {
+            all_probs.insert(all_probs.end(), prompt_probs.begin(), prompt_probs.end());
+        }
+
+        all_probs.insert(all_probs.end(), probs_out.begin(), probs_out.end());
+
+        tokens.reserve(all_probs.size());
+        for (const auto & p : all_probs) {
+            std::string piece = p.text_to_send;
+            piece.resize(validate_utf8(piece));
+            tokens.push_back(piece);
+        }
+
+        int text_offset = 0;
+        std::vector<int> text_offsets;
+        text_offsets.reserve(tokens.size());
+
+        int current_off = text_offset;
+        for (const auto & tok : tokens) {
+            text_offsets.push_back(current_off);
+            current_off += static_cast<int>(tok.size());
+        }
+
+        std::vector<std::optional<float>> token_logprobs;
+        token_logprobs.reserve(all_probs.size());
+
+        std::vector<std::optional<std::unordered_map<std::string, float>>> top_logprobs;
+        top_logprobs.reserve(all_probs.size());
+
+        for (size_t i = 0; i < all_probs.size(); ++i) {
+            const auto & p = all_probs[i];
+
+            if (std::isinf(p.prob) && p.prob < 0) {
+                token_logprobs.push_back(std::nullopt);
+                top_logprobs.push_back(std::nullopt);
+            } else {
+                float logprob_value = p.prob;
+                if (!post_sampling_probs) {
+                    logprob_value = p.prob;
+                } else {
+                    logprob_value = p.prob > 0.0f ? std::log(p.prob) : -std::numeric_limits<float>::infinity();
+                }
+
+                token_logprobs.push_back(std::optional<float>(logprob_value));
+
+                std::unordered_map<std::string, float> top_map;
+                for (const auto & cand : p.probs) {
+                    std::string cand_txt = cand.txt;
+                    cand_txt.resize(validate_utf8(cand_txt));
+
+                    float cand_logprob;
+                    if (!post_sampling_probs) {
+                        cand_logprob = cand.prob;
+                    } else {
+                        cand_logprob = cand.prob > 0.0f ? std::log(cand.prob) : -std::numeric_limits<float>::infinity();
+                    }
+
+                    top_map[cand_txt] = cand_logprob;
+                }
+
+                top_logprobs.push_back(std::move(top_map));
+            }
+        }
+
+        out = json{
+            {"text_offset", text_offsets},
+            {"token_logprobs", token_logprobs},
+            {"tokens", tokens},
+            {"top_logprobs", top_logprobs}
+        };
+
+        return out;
+    }
+
     static float logarithm(float x) {
         // nlohmann::json converts -inf to null, so we need to prevent that
         return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
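
Note: the object this new helper builds follows the OpenAI completions logprobs layout (text_offset / tokens / token_logprobs / top_logprobs). A hand-written sketch of what it could look like for an echoed three-token prompt — the token strings and numbers are illustrative, not output of this commit:

    {
      "text_offset":    [0, 5, 6],
      "tokens":         ["Hello", ",", " world"],
      "token_logprobs": [null, -1.84, -0.07],
      "top_logprobs":   [null, {",": -1.84, "!": -2.31}, {" world": -0.07}]
    }

The first entry is null because the first prompt token has no preceding context to be scored against: it is stored as -inf and the isinf check above maps it to null.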
@@ -697,6 +785,7 @@ struct server_task_result_cmpl_final : server_task_result {
     bool stream;
     result_timings timings;
     std::string prompt;
+    bool echo = false;

     bool truncated;
     int32_t n_decoded;
@@ -708,6 +797,7 @@ struct server_task_result_cmpl_final : server_task_result {

     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
+    std::vector<completion_token_output> prompt_probs_output;
     std::vector<std::string> response_fields;

     slot_params generation_params;
@@ -769,19 +859,26 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_oaicompat() {
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
-        if (!stream && probs_output.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
-            };
+        if (!stream && (probs_output.size() > 0 || (echo && prompt_probs_output.size() > 0))) {
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                probs_output,
+                post_sampling_probs,
+                echo,
+                prompt_probs_output
+            );
         }
         json finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
             finish_reason = "stop";
         }
+        std::string response_text = content;
+        if (echo && !stream) {
+            response_text = prompt + content;
+        }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                    {"text", stream ? "" : response_text}, // in stream mode, content is already in last partial chunk
                     {"index", index},
                     {"logprobs", logprobs},
                     {"finish_reason", finish_reason},
@@ -940,6 +1037,10 @@ struct server_task_result_cmpl_partial : server_task_result {
     std::string oaicompat_cmpl_id;
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

+    bool echo = false;
+    std::string prompt_text;
+    bool is_first_chunk = false;
+
     virtual int get_index() override {
         return index;
     }
@@ -986,14 +1087,21 @@ struct server_task_result_cmpl_partial : server_task_result {
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
         if (prob_output.probs.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-            };
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                std::vector<completion_token_output>{prob_output},
+                post_sampling_probs,
+                echo
+            );
+        }
+
+        std::string response_text = content;
+        if (echo && is_first_chunk) {
+            response_text = prompt_text + content;
         }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text", content},
+                    {"text", response_text},
                     {"index", index},
                     {"logprobs", logprobs},
                     {"finish_reason", nullptr},
@@ -1321,6 +1429,8 @@ struct server_slot {

     // input prompt tokens
     server_tokens prompt_tokens;
+    std::string prompt_text;
+    std::vector<completion_token_output> prompt_token_probs;

     size_t last_nl_pos = 0;

@@ -1368,6 +1478,7 @@ struct server_slot {
         SLT_DBG(*this, "%s", "\n");

         n_prompt_tokens = 0;
+        prompt_text = "";
         last_nl_pos = 0;
         generated_text = "";
         has_new_line = false;
@@ -1381,6 +1492,7 @@ struct server_slot {

         generated_tokens.clear();
         generated_token_probs.clear();
+        prompt_token_probs.clear();
         chat_msg = {};
         json_schema = json();
         generated_tool_call_ids.clear();
@@ -2240,6 +2352,113 @@ struct server_context {
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);

+        if (slot.params.echo) {
+            slot.prompt_text = slot.prompt_tokens.detokenize(ctx, true);
+
+            if (slot.params.sampling.n_probs > 0 && slot.prompt_tokens.size() > 1 && slot.prompt_token_probs.empty()) {
+                slot.prompt_token_probs.reserve(slot.prompt_tokens.size());
+
+                llama_memory_clear(llama_get_memory(ctx), true);
+
+                const int n_batch = llama_n_batch(ctx);
+                const int num_batches = (slot.prompt_tokens.size() + n_batch - 1) / n_batch;
+                const int n_vocab = llama_vocab_n_tokens(vocab);
+
+                std::vector<float> all_logits;
+                if (num_batches > 1) {
+                    all_logits.reserve(slot.prompt_tokens.size() * n_vocab);
+                }
+
+                for (int batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
+                    const int batch_start = batch_idx * n_batch;
+                    const int batch_size = std::min((int)slot.prompt_tokens.size() - batch_start, n_batch);
+
+                    llama_batch batch = llama_batch_init(batch_size, 0, 1);
+                    for (int i = 0; i < batch_size; ++i) {
+                        common_batch_add(batch, slot.prompt_tokens[batch_start + i], batch_start + i, {0}, true);
+                    }
+
+                    if (llama_decode(ctx, batch) == 0) {
+                        const float * batch_logits = llama_get_logits(ctx);
+                        if (num_batches > 1) {
+                            all_logits.insert(all_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+                        }
+                    } else {
+                        llama_batch_free(batch);
+                        break;
+                    }
+                    llama_batch_free(batch);
+                }
+
+                for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                    completion_token_output prompt_token;
+                    prompt_token.tok = slot.prompt_tokens[i];
+                    prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+
+                    if (i == 0) {
+                        prompt_token.prob = -std::numeric_limits<float>::infinity();
+                    } else {
+                        const float * logits = num_batches > 1 ?
+                            all_logits.data() + (i - 1) * n_vocab :
+                            llama_get_logits_ith(ctx, i - 1);
+
+                        if (logits != nullptr) {
+                            float max_logit = logits[0];
+                            for (int j = 1; j < n_vocab; ++j) {
+                                max_logit = std::max(max_logit, logits[j]);
+                            }
+
+                            double sum_exp = 0.0;
+                            for (int j = 0; j < n_vocab; ++j) {
+                                sum_exp += expf(logits[j] - max_logit);
+                            }
+
+                            const float log_sum_exp = max_logit + logf(sum_exp);
+                            prompt_token.prob = logits[slot.prompt_tokens[i]] - log_sum_exp;
+
+                            if (slot.params.sampling.n_probs > 0) {
+                                std::vector<std::pair<float, llama_token>> logits_id;
+                                logits_id.reserve(n_vocab);
+
+                                for (int j = 0; j < n_vocab; j++) {
+                                    const float logprob = logits[j] - log_sum_exp;
+                                    logits_id.emplace_back(logprob, j);
+                                }
+
+                                std::partial_sort(logits_id.begin(),
+                                    logits_id.begin() + std::min((size_t)slot.params.sampling.n_probs, logits_id.size()),
+                                    logits_id.end(),
+                                    [](const auto& a, const auto& b) { return a.first > b.first; });
+
+                                prompt_token.probs.clear();
+                                size_t top_k = std::min(logits_id.size(), static_cast<size_t>(slot.params.sampling.n_probs));
+                                for (size_t k = 0; k < top_k; ++k) {
+                                    completion_token_output::prob_info prob_info;
+                                    prob_info.tok = logits_id[k].second;
+                                    prob_info.prob = logits_id[k].first;
+                                    prob_info.txt = common_token_to_piece(ctx, logits_id[k].second, true);
+                                    prompt_token.probs.push_back(prob_info);
+                                }
+                            }
+                        } else {
+                            prompt_token.prob = -std::numeric_limits<float>::infinity();
+                        }
+                    }
+
+                    slot.prompt_token_probs.push_back(prompt_token);
+                }
+            } else {
+                for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                    completion_token_output prompt_token;
+                    prompt_token.tok = slot.prompt_tokens[i];
+                    prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+                    prompt_token.prob = -std::numeric_limits<float>::infinity();
+                    slot.prompt_token_probs.push_back(prompt_token);
+                }
+            }
+        }
+
+
         if (!are_lora_equal(slot.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
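
Note: when echo is combined with n_probs > 0, the slot makes one extra pass over its own prompt here (after clearing the memory/KV state) and scores each prompt token against the logits of the previous position, i.e. logprob(t_i) = logits_{i-1}[t_i] - log(sum_j exp(logits_{i-1}[j])), computed with the usual max-subtraction for numerical stability. Token 0 has no predecessor and is stored as -inf, which the JSON conversion later renders as null. With n_probs == 0, only the token text is kept and every prompt logprob stays at -inf.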
@@ -2529,6 +2748,10 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens = { tkn.tok };

+        res->echo = slot.params.echo;
+        res->prompt_text = slot.prompt_text;
+        res->is_first_chunk = (slot.n_decoded == 1);
+
         res->n_decoded = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
         res->post_sampling_probs = slot.params.post_sampling_probs;
@@ -2562,7 +2785,9 @@ struct server_context {
         res->content = slot.generated_text;
         res->tokens = std::move(slot.generated_tokens);
         res->timings = slot.get_timings();
-        res->prompt = slot.prompt_tokens.detokenize(ctx, true);
+
+        res->echo = slot.params.echo;
+        res->prompt = slot.params.echo ? slot.prompt_text : slot.prompt_tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.params.response_fields);

         res->truncated = slot.truncated;
@@ -2595,6 +2820,10 @@ struct server_context {
                 slot.generated_token_probs.begin(),
                 slot.generated_token_probs.end());
         }
+
+        if (slot.params.echo && !slot.prompt_token_probs.empty()) {
+            res->prompt_probs_output = slot.prompt_token_probs;
+        }
     }

     res->generation_params = slot.params; // copy the parameters

tools/server/utils.hpp

Lines changed: 0 additions & 5 deletions
@@ -553,11 +553,6 @@ static json oaicompat_completion_params_parse(const json & body) {
         throw std::runtime_error("Only one completion choice is allowed");
     }

-    // Handle "echo" field
-    if (json_value(body, "echo", false)) {
-        throw std::runtime_error("Only no echo is supported");
-    }
-
     // Params supported by OAI but unsupported by llama.cpp
     static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
    for (const auto & param : unsupported_params) {
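
With this guard removed, echo requests are accepted rather than rejected. A minimal request body sketch for POST /v1/completions — prompt, max_tokens and n_probs are existing server fields assumed here for illustration; only echo is introduced by this commit:

    {
      "prompt": "Hello, world",
      "max_tokens": 8,
      "echo": true,
      "n_probs": 3
    }

With echo set and n_probs > 0, the first choice of the response carries the prompt followed by the completion in "text", and its logprobs cover the prompt tokens as well as the generated ones, per the server.cpp changes above.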
