diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index a255d481a4d1c..be267c4405c8d 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -112,6 +112,7 @@ struct slot_params {
     bool stream = true;
     bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
     bool return_tokens = false;
+    bool echo = false;
 
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -160,6 +161,7 @@ struct slot_params {
         }
 
         return json {
+            {"echo", echo},
            {"n_predict", n_predict}, // Server configured n_predict
            {"seed", sampling.seed},
            {"temperature", sampling.temp},
@@ -265,6 +267,7 @@ struct server_task {
         params.stream = json_value(data, "stream", false);
         params.cache_prompt = json_value(data, "cache_prompt", true);
         params.return_tokens = json_value(data, "return_tokens", false);
+        params.echo = json_value(data, "echo", false);
         params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent = json_value(data, "n_indent", defaults.n_indent);
         params.n_keep = json_value(data, "n_keep", defaults.n_keep);
@@ -674,6 +677,91 @@ struct completion_token_output {
         return out;
     }
 
+    static json oaicompat_probs_vector_to_json(
+            const std::vector<completion_token_output> & probs_out,
+            bool post_sampling_probs,
+            bool echo,
+            const std::vector<completion_token_output> & prompt_probs = {}
+    ) {
+        json out = json::object();
+
+        std::vector<std::string> tokens;
+        std::vector<completion_token_output> all_probs;
+
+        if (echo && !prompt_probs.empty()) {
+            all_probs.insert(all_probs.end(), prompt_probs.begin(), prompt_probs.end());
+        }
+
+        all_probs.insert(all_probs.end(), probs_out.begin(), probs_out.end());
+
+        tokens.reserve(all_probs.size());
+        for (const auto & p : all_probs) {
+            std::string piece = p.text_to_send;
+            piece.resize(validate_utf8(piece));
+            tokens.push_back(piece);
+        }
+
+        int text_offset = 0;
+        std::vector<int> text_offsets;
+        text_offsets.reserve(tokens.size());
+
+        int current_off = text_offset;
+        for (const auto & tok : tokens) {
+            text_offsets.push_back(current_off);
+            current_off += static_cast<int>(tok.size());
+        }
+
+        std::vector<std::optional<float>> token_logprobs;
+        token_logprobs.reserve(all_probs.size());
+
+        std::vector<std::optional<std::unordered_map<std::string, float>>> top_logprobs;
+        top_logprobs.reserve(all_probs.size());
+
+        for (size_t i = 0; i < all_probs.size(); ++i) {
+            const auto & p = all_probs[i];
+
+            if (std::isinf(p.prob) && p.prob < 0) {
+                token_logprobs.push_back(std::nullopt);
+                top_logprobs.push_back(std::nullopt);
+            } else {
+                float logprob_value = p.prob;
+                if (!post_sampling_probs) {
+                    logprob_value = p.prob;
+                } else {
+                    logprob_value = p.prob > 0.0f ? std::log(p.prob) : -std::numeric_limits<float>::infinity();
+                }
+
+                token_logprobs.push_back(std::optional<float>(logprob_value));
+
+                std::unordered_map<std::string, float> top_map;
+                for (const auto & cand : p.probs) {
+                    std::string cand_txt = cand.txt;
+                    cand_txt.resize(validate_utf8(cand_txt));
+
+                    float cand_logprob;
+                    if (!post_sampling_probs) {
+                        cand_logprob = cand.prob;
+                    } else {
+                        cand_logprob = cand.prob > 0.0f ? std::log(cand.prob) : -std::numeric_limits<float>::infinity();
+                    }
+
+                    top_map[cand_txt] = cand_logprob;
+                }
+
+                top_logprobs.push_back(std::move(top_map));
+            }
+        }
+
+        out = json{
+            {"text_offset",    text_offsets},
+            {"token_logprobs", token_logprobs},
+            {"tokens",         tokens},
+            {"top_logprobs",   top_logprobs}
+        };
+
+        return out;
+    }
+
     static float logarithm(float x) {
         // nlohmann::json converts -inf to null, so we need to prevent that
         return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
@@ -697,6 +785,7 @@ struct server_task_result_cmpl_final : server_task_result {
     bool stream;
     result_timings timings;
     std::string prompt;
+    bool echo = false;
 
     bool truncated;
     int32_t n_decoded;
@@ -708,6 +797,7 @@ struct server_task_result_cmpl_final : server_task_result {
     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
+    std::vector<completion_token_output> prompt_probs_output;
     std::vector<std::string> response_fields;
 
     slot_params generation_params;
 
@@ -769,19 +859,26 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_oaicompat() {
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
-        if (!stream && probs_output.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
-            };
+        if (!stream && (probs_output.size() > 0 || (echo && prompt_probs_output.size() > 0))) {
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                probs_output,
+                post_sampling_probs,
+                echo,
+                prompt_probs_output
+            );
         }
         json finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
             finish_reason = "stop";
         }
+        std::string response_text = content;
+        if (echo && !stream) {
+            response_text = prompt + content;
+        }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                    {"text", stream ? "" : response_text}, // in stream mode, content is already in last partial chunk
                     {"index", index},
                     {"logprobs", logprobs},
                     {"finish_reason", finish_reason},
@@ -940,6 +1037,10 @@ struct server_task_result_cmpl_partial : server_task_result {
     std::string oaicompat_cmpl_id;
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
 
+    bool echo = false;
+    std::string prompt_text;
+    bool is_first_chunk = false;
+
     virtual int get_index() override {
         return index;
     }
@@ -986,14 +1087,21 @@
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
         if (prob_output.probs.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-            };
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                std::vector<completion_token_output>{prob_output},
+                post_sampling_probs,
+                echo
+            );
+        }
+
+        std::string response_text = content;
+        if (echo && is_first_chunk) {
+            response_text = prompt_text + content;
         }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text", content},
+                    {"text", response_text},
                     {"index", index},
                     {"logprobs", logprobs},
                     {"finish_reason", nullptr},
@@ -1321,6 +1429,8 @@ struct server_slot {
 
     // input prompt tokens
     server_tokens prompt_tokens;
+    std::string prompt_text;
+    std::vector<completion_token_output> prompt_token_probs;
 
     size_t last_nl_pos = 0;
 
@@ -1368,6 +1478,7 @@ struct server_slot {
         SLT_DBG(*this, "%s", "\n");
 
         n_prompt_tokens = 0;
+        prompt_text = "";
         last_nl_pos = 0;
         generated_text = "";
         has_new_line = false;
@@ -1381,6 +1492,7 @@ struct server_slot {
 
         generated_tokens.clear();
         generated_token_probs.clear();
+        prompt_token_probs.clear();
         chat_msg = {};
         json_schema = json();
         generated_tool_call_ids.clear();
@@ -2240,6 +2352,113 @@ struct server_context {
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
+        if (slot.params.echo) {
+            slot.prompt_text = slot.prompt_tokens.detokenize(ctx, true);
+
+            if (slot.params.sampling.n_probs > 0 && slot.prompt_tokens.size() > 1 && slot.prompt_token_probs.empty()) {
+                slot.prompt_token_probs.reserve(slot.prompt_tokens.size());
+
+                llama_memory_clear(llama_get_memory(ctx), true);
+
+                const int n_batch = llama_n_batch(ctx);
+                const int num_batches = (slot.prompt_tokens.size() + n_batch - 1) / n_batch;
+                const int n_vocab = llama_vocab_n_tokens(vocab);
+
+                std::vector<float> all_logits;
+                if (num_batches > 1) {
+                    all_logits.reserve(slot.prompt_tokens.size() * n_vocab);
+                }
+
+                for (int batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
+                    const int batch_start = batch_idx * n_batch;
+                    const int batch_size = std::min((int)slot.prompt_tokens.size() - batch_start, n_batch);
+
+                    llama_batch batch = llama_batch_init(batch_size, 0, 1);
+                    for (int i = 0; i < batch_size; ++i) {
+                        common_batch_add(batch, slot.prompt_tokens[batch_start + i], batch_start + i, {0}, true);
+                    }
+
+                    if (llama_decode(ctx, batch) == 0) {
+                        const float * batch_logits = llama_get_logits(ctx);
+                        if (num_batches > 1) {
+                            all_logits.insert(all_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+                        }
+                    } else {
+                        llama_batch_free(batch);
+                        break;
+                    }
+                    llama_batch_free(batch);
+                }
+
+                for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                    completion_token_output prompt_token;
+                    prompt_token.tok = slot.prompt_tokens[i];
+                    prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+
+                    if (i == 0) {
+                        prompt_token.prob = -std::numeric_limits<float>::infinity();
+                    } else {
+                        const float * logits = num_batches > 1 ?
+                            all_logits.data() + (i - 1) * n_vocab :
+                            llama_get_logits_ith(ctx, i - 1);
+
+                        if (logits != nullptr) {
+                            float max_logit = logits[0];
+                            for (int j = 1; j < n_vocab; ++j) {
+                                max_logit = std::max(max_logit, logits[j]);
+                            }
+
+                            double sum_exp = 0.0;
+                            for (int j = 0; j < n_vocab; ++j) {
+                                sum_exp += expf(logits[j] - max_logit);
+                            }
+
+                            const float log_sum_exp = max_logit + logf(sum_exp);
+                            prompt_token.prob = logits[slot.prompt_tokens[i]] - log_sum_exp;
+
+                            if (slot.params.sampling.n_probs > 0) {
+                                std::vector<std::pair<float, int>> logits_id;
+                                logits_id.reserve(n_vocab);
+
+                                for (int j = 0; j < n_vocab; j++) {
+                                    const float logprob = logits[j] - log_sum_exp;
+                                    logits_id.emplace_back(logprob, j);
+                                }
+
+                                std::partial_sort(logits_id.begin(),
+                                    logits_id.begin() + std::min((size_t)slot.params.sampling.n_probs, logits_id.size()),
+                                    logits_id.end(),
+                                    [](const auto& a, const auto& b) { return a.first > b.first; });
+
+                                prompt_token.probs.clear();
+                                size_t top_k = std::min(logits_id.size(), static_cast<size_t>(slot.params.sampling.n_probs));
+                                for (size_t k = 0; k < top_k; ++k) {
+                                    completion_token_output::prob_info prob_info;
+                                    prob_info.tok = logits_id[k].second;
+                                    prob_info.prob = logits_id[k].first;
+                                    prob_info.txt = common_token_to_piece(ctx, logits_id[k].second, true);
+                                    prompt_token.probs.push_back(prob_info);
+                                }
+                            }
+                        } else {
+                            prompt_token.prob = -std::numeric_limits<float>::infinity();
+                        }
+                    }
+
+                    slot.prompt_token_probs.push_back(prompt_token);
+                }
+            } else {
+                for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                    completion_token_output prompt_token;
+                    prompt_token.tok = slot.prompt_tokens[i];
+                    prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+                    prompt_token.prob = -std::numeric_limits<float>::infinity();
+                    slot.prompt_token_probs.push_back(prompt_token);
+                }
+            }
+        }
+
         if (!are_lora_equal(slot.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
@@ -2529,6 +2748,10 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens  = { tkn.tok };
 
+        res->echo = slot.params.echo;
+        res->prompt_text = slot.prompt_text;
+        res->is_first_chunk = (slot.n_decoded == 1);
+
         res->n_decoded = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
         res->post_sampling_probs = slot.params.post_sampling_probs;
@@ -2562,7 +2785,9 @@ struct server_context {
         res->content = slot.generated_text;
         res->tokens  = std::move(slot.generated_tokens);
         res->timings = slot.get_timings();
-        res->prompt  = slot.prompt_tokens.detokenize(ctx, true);
+
+        res->echo    = slot.params.echo;
+        res->prompt  = slot.params.echo ? slot.prompt_text : slot.prompt_tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.params.response_fields);
 
         res->truncated = slot.truncated;
@@ -2595,6 +2820,10 @@ struct server_context {
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
             }
+
+            if (slot.params.echo && !slot.prompt_token_probs.empty()) {
+                res->prompt_probs_output = slot.prompt_token_probs;
+            }
         }
 
         res->generation_params = slot.params; // copy the parameters
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index f3dfc8225da4d..2abafe64849e0 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -553,11 +553,6 @@ static json oaicompat_completion_params_parse(const json & body) {
         throw std::runtime_error("Only one completion choice is allowed");
     }
 
-    // Handle "echo" field
-    if (json_value(body, "echo", false)) {
-        throw std::runtime_error("Only no echo is supported");
-    }
-
     // Params supported by OAI but unsupported by llama.cpp
     static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
     for (const auto & param : unsupported_params) {
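
Note: the heart of the server.cpp change is the prompt-logprob pass in launch_slot_with_task(). For each prompt position it reduces the raw logits with a max-shifted log-sum-exp and, when n_probs is set, keeps the top candidates with std::partial_sort. The standalone sketch below shows the same technique on a toy vocabulary; it is illustrative only (the helper names and the 4-entry vocabulary are made up here) and is not code from the patch.

```cpp
// Standalone sketch (not llama.cpp code): numerically stable log-softmax plus
// top-k selection, the technique the patch applies to each prompt position.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// log P(token = target | logits) = logits[target] - logsumexp(logits),
// with the max subtracted first so the exponentials cannot overflow.
static float token_logprob(const std::vector<float> & logits, int target) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (float l : logits) {
        sum_exp += std::exp((double) (l - max_logit));
    }
    const float log_sum_exp = max_logit + (float) std::log(sum_exp);
    return logits[target] - log_sum_exp;
}

// top-k (logprob, token-id) pairs, mirroring the std::partial_sort in the patch
static std::vector<std::pair<float, int>> top_k_logprobs(const std::vector<float> & logits, size_t k) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (float l : logits) {
        sum_exp += std::exp((double) (l - max_logit));
    }
    const float log_sum_exp = max_logit + (float) std::log(sum_exp);

    std::vector<std::pair<float, int>> id_logprob;
    id_logprob.reserve(logits.size());
    for (int j = 0; j < (int) logits.size(); ++j) {
        id_logprob.emplace_back(logits[j] - log_sum_exp, j);
    }
    k = std::min(k, id_logprob.size());
    std::partial_sort(id_logprob.begin(), id_logprob.begin() + k, id_logprob.end(),
                      [](const auto & a, const auto & b) { return a.first > b.first; });
    id_logprob.resize(k);
    return id_logprob;
}

int main() {
    const std::vector<float> logits = { 2.0f, 0.5f, -1.0f, 3.0f }; // toy 4-token vocabulary
    std::printf("logprob of token 3: %.4f\n", token_logprob(logits, 3));
    for (const auto & [lp, id] : top_k_logprobs(logits, 2)) {
        std::printf("top candidate %d: %.4f\n", id, lp);
    }
    return 0;
}
```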
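On the response side, oaicompat_probs_vector_to_json() builds the OpenAI-completions-style object with parallel text_offset, token_logprobs, tokens and top_logprobs arrays, and maps the -inf entries (notably the first prompt token, which has no predecessor to condition on) to null token_logprobs entries. A hypothetical client could turn those values into a perplexity figure roughly as follows; this is a sketch only, and it assumes the JSON has already been parsed into std::optional<float> values.

```cpp
// Client-side sketch (not part of the patch): perplexity over the returned
// token_logprobs, skipping null entries such as the first prompt token.
#include <cmath>
#include <cstdio>
#include <optional>
#include <vector>

static double perplexity(const std::vector<std::optional<float>> & token_logprobs) {
    double sum = 0.0;
    size_t n = 0;
    for (const auto & lp : token_logprobs) {
        if (!lp.has_value()) {
            continue; // null entries carry no probability information
        }
        sum += *lp;
        ++n;
    }
    return n == 0 ? 0.0 : std::exp(-sum / (double) n);
}

int main() {
    // toy values standing in for a parsed "token_logprobs" array
    const std::vector<std::optional<float>> token_logprobs = {
        std::nullopt, -1.2f, -0.4f, -2.1f,
    };
    std::printf("perplexity over returned tokens: %.3f\n", perplexity(token_logprobs));
    return 0;
}
```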