server: implementation of v1/completions echo logprobs support #15189

Open: wants to merge 1 commit into master.
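
This PR implements support for the OpenAI-style echo parameter on POST /v1/completions, which the server previously rejected with "Only no echo is supported". With "echo": true the prompt text is prepended to the returned "text", and when logprobs are also requested, logprobs for the prompt tokens are emitted ahead of those for the generated tokens. A minimal request sketch (the model name and prompt are placeholders, and it assumes the usual mapping of the OAI "logprobs" field onto the server's internal n_probs):

{
  "model": "llama",
  "prompt": "Hello, world",
  "max_tokens": 8,
  "echo": true,
  "logprobs": 3
}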
249 changes: 239 additions & 10 deletions tools/server/server.cpp
@@ -112,6 +112,7 @@ struct slot_params {
bool stream = true;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
bool echo = false;

int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -160,6 +161,7 @@ struct slot_params {
}

return json {
{"echo", echo},
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
{"temperature", sampling.temp},
@@ -265,6 +267,7 @@ struct server_task {
params.stream = json_value(data, "stream", false);
params.cache_prompt = json_value(data, "cache_prompt", true);
params.return_tokens = json_value(data, "return_tokens", false);
params.echo = json_value(data, "echo", false);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
@@ -674,6 +677,91 @@ struct completion_token_output {
return out;
}

static json oaicompat_probs_vector_to_json(
const std::vector<completion_token_output> & probs_out,
bool post_sampling_probs,
bool echo,
const std::vector<completion_token_output> & prompt_probs = {}
) {
json out = json::object();

std::vector<std::string> tokens;
std::vector<completion_token_output> all_probs;

if (echo && !prompt_probs.empty()) {
all_probs.insert(all_probs.end(), prompt_probs.begin(), prompt_probs.end());
}

all_probs.insert(all_probs.end(), probs_out.begin(), probs_out.end());

tokens.reserve(all_probs.size());
for (const auto & p : all_probs) {
std::string piece = p.text_to_send;
piece.resize(validate_utf8(piece));
tokens.push_back(piece);
}

std::vector<int> text_offsets;
text_offsets.reserve(tokens.size());

int current_off = 0;
for (const auto & tok : tokens) {
text_offsets.push_back(current_off);
current_off += static_cast<int>(tok.size());
}

std::vector<std::optional<float>> token_logprobs;
token_logprobs.reserve(all_probs.size());

std::vector<std::optional<std::unordered_map<std::string, float>>> top_logprobs;
top_logprobs.reserve(all_probs.size());

for (size_t i = 0; i < all_probs.size(); ++i) {
const auto & p = all_probs[i];

if (std::isinf(p.prob) && p.prob < 0) {
token_logprobs.push_back(std::nullopt);
top_logprobs.push_back(std::nullopt);
} else {
// without post_sampling_probs the stored values are already log-probabilities;
// otherwise they are plain probabilities and must be converted
float logprob_value;
if (!post_sampling_probs) {
logprob_value = p.prob;
} else {
logprob_value = p.prob > 0.0f ? std::log(p.prob) : -std::numeric_limits<float>::infinity();
}

token_logprobs.push_back(std::optional<float>(logprob_value));

std::unordered_map<std::string, float> top_map;
for (const auto & cand : p.probs) {
std::string cand_txt = cand.txt;
cand_txt.resize(validate_utf8(cand_txt));

float cand_logprob;
if (!post_sampling_probs) {
cand_logprob = cand.prob;
} else {
cand_logprob = cand.prob > 0.0f ? std::log(cand.prob) : -std::numeric_limits<float>::infinity();
}

top_map[cand_txt] = cand_logprob;
}

top_logprobs.push_back(std::move(top_map));
}
}

out = json{
{"text_offset", text_offsets},
{"token_logprobs", token_logprobs},
{"tokens", tokens},
{"top_logprobs", top_logprobs}
};

return out;
}
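
The object built above mirrors the logprobs shape of OpenAI's legacy completions API: four parallel arrays indexed by token position. A hand-written sketch of the output for a two-token prompt ("Hello", ",") echoed ahead of one generated token (" world"), with two top candidates per position; all token strings and values are made-up placeholders:

{
  "text_offset":    [0, 5, 6],
  "tokens":         ["Hello", ",", " world"],
  "token_logprobs": [null, -2.31, -0.84],
  "top_logprobs":   [null, {",": -2.31, ".": -2.50}, {" world": -0.84, "!": -1.92}]
}

The leading null entries belong to the first prompt token, whose logprob is undefined because it has no preceding context.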

static float logarithm(float x) {
// nlohmann::json converts -inf to null, so we need to prevent that
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
@@ -697,6 +785,7 @@ struct server_task_result_cmpl_final : server_task_result {
bool stream;
result_timings timings;
std::string prompt;
bool echo = false;

bool truncated;
int32_t n_decoded;
@@ -708,6 +797,7 @@

bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
std::vector<completion_token_output> prompt_probs_output;
std::vector<std::string> response_fields;

slot_params generation_params;
@@ -769,19 +859,26 @@
json to_json_oaicompat() {
std::time_t t = std::time(0);
json logprobs = json(nullptr); // OAI default to null
if (!stream && probs_output.size() > 0) {
logprobs = json{
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
};
if (!stream && (probs_output.size() > 0 || (echo && prompt_probs_output.size() > 0))) {
logprobs = completion_token_output::oaicompat_probs_vector_to_json(
probs_output,
post_sampling_probs,
echo,
prompt_probs_output
);
}
json finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}
std::string response_text = content;
if (echo && !stream) {
response_text = prompt + content;
}
json res = json {
{"choices", json::array({
json{
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"text", stream ? "" : response_text}, // in stream mode, content is already in last partial chunk
{"index", index},
{"logprobs", logprobs},
{"finish_reason", finish_reason},
@@ -940,6 +1037,10 @@ struct server_task_result_cmpl_partial : server_task_result {
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

bool echo = false;
std::string prompt_text;
bool is_first_chunk = false;

virtual int get_index() override {
return index;
}
@@ -986,14 +1087,21 @@
std::time_t t = std::time(0);
json logprobs = json(nullptr); // OAI default to null
if (prob_output.probs.size() > 0) {
logprobs = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
logprobs = completion_token_output::oaicompat_probs_vector_to_json(
std::vector<completion_token_output>{prob_output},
post_sampling_probs,
echo
);
}

std::string response_text = content;
if (echo && is_first_chunk) {
response_text = prompt_text + content;
}
json res = json {
{"choices", json::array({
json{
{"text", content},
{"text", response_text},
{"index", index},
{"logprobs", logprobs},
{"finish_reason", nullptr},
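
In stream mode with echo enabled, only the first chunk carries the prompt: the is_first_chunk flag above causes prompt_text to be prepended to the first sampled token. An illustrative first chunk, with placeholder text and logprobs not requested:

{"choices": [{"text": "Hello, world Every", "index": 0, "logprobs": null, "finish_reason": null}]}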
@@ -1321,6 +1429,8 @@ struct server_slot {

// input prompt tokens
server_tokens prompt_tokens;
std::string prompt_text;
std::vector<completion_token_output> prompt_token_probs;

size_t last_nl_pos = 0;

@@ -1368,6 +1478,7 @@
SLT_DBG(*this, "%s", "\n");

n_prompt_tokens = 0;
prompt_text = "";
last_nl_pos = 0;
generated_text = "";
has_new_line = false;
@@ -1381,6 +1492,7 @@

generated_tokens.clear();
generated_token_probs.clear();
prompt_token_probs.clear();
chat_msg = {};
json_schema = json();
generated_tool_call_ids.clear();
@@ -2240,6 +2352,113 @@ struct server_context {
slot.params = std::move(task.params);
slot.prompt_tokens = std::move(task.prompt_tokens);

if (slot.params.echo) {
slot.prompt_text = slot.prompt_tokens.detokenize(ctx, true);

if (slot.params.sampling.n_probs > 0 && slot.prompt_tokens.size() > 1 && slot.prompt_token_probs.empty()) {
slot.prompt_token_probs.reserve(slot.prompt_tokens.size());

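// clear the KV cache and re-decode the entire prompt with logits enabled at
// every position, so that per-token prompt logprobs can be computed below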
llama_memory_clear(llama_get_memory(ctx), true);

const int n_batch = llama_n_batch(ctx);
const int num_batches = (slot.prompt_tokens.size() + n_batch - 1) / n_batch;
const int n_vocab = llama_vocab_n_tokens(vocab);

std::vector<float> all_logits;
if (num_batches > 1) {
all_logits.reserve(slot.prompt_tokens.size() * n_vocab);
}

for (int batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
const int batch_start = batch_idx * n_batch;
const int batch_size = std::min((int)slot.prompt_tokens.size() - batch_start, n_batch);

llama_batch batch = llama_batch_init(batch_size, 0, 1);
for (int i = 0; i < batch_size; ++i) {
common_batch_add(batch, slot.prompt_tokens[batch_start + i], batch_start + i, {0}, true);
}

if (llama_decode(ctx, batch) == 0) {
const float * batch_logits = llama_get_logits(ctx);
if (num_batches > 1) {
all_logits.insert(all_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
} else {
llama_batch_free(batch);
break;
}
llama_batch_free(batch);
}

for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
completion_token_output prompt_token;
prompt_token.tok = slot.prompt_tokens[i];
prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);

if (i == 0) {
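// the first prompt token has no preceding context to be predicted from;
// mark it with -inf, which is serialized as null in the response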
prompt_token.prob = -std::numeric_limits<float>::infinity();
} else {
const float * logits = nullptr;
if (num_batches > 1) {
// guard against a partially filled buffer if a batch failed to decode
if (all_logits.size() >= i * (size_t) n_vocab) {
logits = all_logits.data() + (i - 1) * n_vocab;
}
} else {
logits = llama_get_logits_ith(ctx, i - 1);
}

if (logits != nullptr) {
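// numerically stable log-softmax over the full vocabulary:
// logprob(tok) = logit[tok] - (max_logit + log(sum_j exp(logit[j] - max_logit)))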
float max_logit = logits[0];
for (int j = 1; j < n_vocab; ++j) {
max_logit = std::max(max_logit, logits[j]);
}

double sum_exp = 0.0;
for (int j = 0; j < n_vocab; ++j) {
sum_exp += expf(logits[j] - max_logit);
}

const float log_sum_exp = max_logit + logf(sum_exp);
prompt_token.prob = logits[slot.prompt_tokens[i]] - log_sum_exp;

if (slot.params.sampling.n_probs > 0) {
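// collect the top n_probs candidate tokens at this position, sorted by logprob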
std::vector<std::pair<float, llama_token>> logits_id;
logits_id.reserve(n_vocab);

for (int j = 0; j < n_vocab; j++) {
const float logprob = logits[j] - log_sum_exp;
logits_id.emplace_back(logprob, j);
}

std::partial_sort(logits_id.begin(),
logits_id.begin() + std::min((size_t)slot.params.sampling.n_probs, logits_id.size()),
logits_id.end(),
[](const auto& a, const auto& b) { return a.first > b.first; });

prompt_token.probs.clear();
size_t top_k = std::min(logits_id.size(), static_cast<size_t>(slot.params.sampling.n_probs));
for (size_t k = 0; k < top_k; ++k) {
completion_token_output::prob_info prob_info;
prob_info.tok = logits_id[k].second;
prob_info.prob = logits_id[k].first;
prob_info.txt = common_token_to_piece(ctx, logits_id[k].second, true);
prompt_token.probs.push_back(prob_info);
}
}
} else {
prompt_token.prob = -std::numeric_limits<float>::infinity();
}
}

slot.prompt_token_probs.push_back(prompt_token);
}
} else {
for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
completion_token_output prompt_token;
prompt_token.tok = slot.prompt_tokens[i];
prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
prompt_token.prob = -std::numeric_limits<float>::infinity();
slot.prompt_token_probs.push_back(prompt_token);
}
}
}


if (!are_lora_equal(slot.params.lora, slot.lora)) {
// if lora is changed, we cannot reuse cached tokens
slot.cache_tokens.clear();
@@ -2529,6 +2748,10 @@ struct server_context {
res->content = tkn.text_to_send;
res->tokens = { tkn.tok };

res->echo = slot.params.echo;
res->prompt_text = slot.prompt_text;
res->is_first_chunk = (slot.n_decoded == 1);

res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;
res->post_sampling_probs = slot.params.post_sampling_probs;
@@ -2562,7 +2785,9 @@
res->content = slot.generated_text;
res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings();
res->prompt = slot.prompt_tokens.detokenize(ctx, true);

res->echo = slot.params.echo;
res->prompt = slot.params.echo ? slot.prompt_text : slot.prompt_tokens.detokenize(ctx, true);
res->response_fields = std::move(slot.params.response_fields);

res->truncated = slot.truncated;
@@ -2595,6 +2820,10 @@
slot.generated_token_probs.begin(),
slot.generated_token_probs.end());
}

if (slot.params.echo && !slot.prompt_token_probs.empty()) {
res->prompt_probs_output = slot.prompt_token_probs;
}
}

res->generation_params = slot.params; // copy the parameters
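
Putting it together, an illustrative excerpt of a non-streaming response with echo enabled (placeholder text; the logprobs object, when requested, has the shape sketched earlier):

{"choices": [{"text": "Hello, world Every good boy", "index": 0, "logprobs": null, "finish_reason": "length"}]}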
5 changes: 0 additions & 5 deletions tools/server/utils.hpp
@@ -553,11 +553,6 @@ static json oaicompat_completion_params_parse(const json & body) {
throw std::runtime_error("Only one completion choice is allowed");
}

// Handle "echo" field
if (json_value(body, "echo", false)) {
throw std::runtime_error("Only no echo is supported");
}

// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
for (const auto & param : unsupported_params) {