174 changes: 88 additions & 86 deletions examples/server/server.cpp
@@ -392,7 +392,7 @@ struct server_task_result {
return false;
}
virtual bool is_stop() {
// only used by server_task_result_cmpl_partial
// only used by server_task_result_cmpl_*
return false;
}
virtual int get_index() {
@@ -478,14 +478,20 @@ struct server_task_result_cmpl_final : server_task_result {
return index;
}

virtual bool is_stop() override {
return true; // in stream mode, final responses are considered stop
}

virtual json to_json() override {
return oaicompat ? to_json_oaicompat_chat() : to_json_non_oaicompat();
return oaicompat
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
: to_json_non_oaicompat();
}

json to_json_non_oaicompat() {
json res = json {
{"index", index},
{"content", content},
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"id_slot", id_slot},
{"stop", true},
{"model", oaicompat_model},
@@ -546,18 +552,46 @@ struct server_task_result_cmpl_final : server_task_result {

return res;
}

json to_json_oaicompat_chat_stream() {
std::time_t t = std::time(0);
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}

json choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});

json ret = json {
{"choices", choices},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"},
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
};

if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
}

return ret;
}
};

struct server_task_result_cmpl_partial : server_task_result {
int index = 0;
std::string content;

bool truncated;
int32_t n_decoded;
int32_t n_prompt_tokens;

stop_type stop = STOP_TYPE_NONE;

std::vector<completion_token_output> probs_output;
result_timings timings;

@@ -573,20 +607,19 @@ struct server_task_result_cmpl_partial : server_task_result {
}

virtual bool is_stop() override {
return stop != STOP_TYPE_NONE;
return false; // in stream mode, partial responses are not considered stop
}

virtual json to_json() override {
if (oaicompat) {
return to_json_oaicompat();
}
bool is_stop = stop != STOP_TYPE_NONE;
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
}

json to_json_non_oaicompat() {
// non-OAI-compat JSON
json res = json {
{"index", index},
{"content", content},
{"stop_type", stop_type_to_str(stop)},
{"stop", is_stop},
{"stop", false},
{"id_slot", id_slot},
{"tokens_predicted", n_decoded},
{"tokens_evaluated", n_prompt_tokens},
@@ -598,72 +631,54 @@ struct server_task_result_cmpl_partial : server_task_result {
if (!probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
}
if (is_stop) {
res.push_back({"truncated", truncated});
}
return res;
}

json to_json_oaicompat() {
bool first = n_decoded == 0;

std::string finish_reason;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
} else if (stop == STOP_TYPE_LIMIT) {
finish_reason = "length";
}

std::time_t t = std::time(0);

json choices;

if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};

json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};

return std::vector<json>({initial_ret, second_ret});
}
{"delta", json{{"role", "assistant"}}}}});
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};

json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"}};

return std::vector<json>({initial_ret, second_ret});
}
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
}

json ret = json {
@@ -678,14 +693,6 @@ struct server_task_result_cmpl_partial : server_task_result {
ret.push_back({"timings", timings.to_json()});
}

if (!finish_reason.empty()) {
ret.push_back({"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}});
}

return std::vector<json>({ret});
}
};
@@ -1888,12 +1895,9 @@ struct server_context {
res->index = slot.index;
res->content = tkn.text_to_send;

res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;

res->stop = slot.stop;

res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
@@ -1924,12 +1928,6 @@ struct server_context {
}

void send_final_response(server_slot & slot) {
if (slot.params.stream) {
// if in stream mode, send the last partial response
send_partial_response(slot, {0, "", {}});
return;
}

auto res = std::make_unique<server_task_result_cmpl_final>();
res->id = slot.id_task;
res->id_slot = slot.id;
Expand All @@ -1948,6 +1946,7 @@ struct server_context {
res->stop = slot.stop;

res->verbose = slot.params.verbose;
res->stream = slot.params.stream;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
res->oaicompat_model = slot.params.oaicompat_model;
@@ -2100,7 +2099,10 @@ struct server_context {
return;
}

GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr);
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
if (!result_handler(result)) {
cancel_tasks(id_tasks);
break;
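Taken together, the server.cpp changes move stream termination out of server_task_result_cmpl_partial and into server_task_result_cmpl_final: partial chunks now always answer is_stop() with false and serialize "stop": false, while the final result records the stream flag and, in stream mode, emits its own terminating chunk via to_json_oaicompat_chat_stream() with an empty delta, a finish_reason, and the usage counters that previously lived on the last partial chunk. A client therefore sees usage (and timings, when enabled) only on the final chunk. Below is a minimal, hypothetical Python client sketch of that contract; the host, port, and the [DONE] sentinel check are illustrative assumptions, not part of this diff.

import json
import requests  # assumed dependency; any SSE-capable HTTP client works

# Hypothetical local endpoint; adjust host/port to your server invocation.
url = "http://localhost:8080/v1/chat/completions"
payload = {
    "model": "whatever",  # a llama.cpp server serves its loaded model regardless
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as resp:
    for raw in resp.iter_lines():
        if not raw.startswith(b"data: "):
            continue  # skip blank keep-alive lines between SSE events
        data = raw[len(b"data: "):]
        if data == b"[DONE]":  # assumed terminator; some deployments omit it
            break
        chunk = json.loads(data)
        choice = chunk["choices"][0]
        if choice["finish_reason"] is None:
            # intermediate chunk: content arrives incrementally in the delta
            print(choice["delta"].get("content", ""), end="")
        else:
            # final chunk (per this PR): empty delta, finish_reason and usage set
            print("\nfinish_reason:", choice["finish_reason"])
            print("usage:", chunk["usage"])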
6 changes: 6 additions & 0 deletions examples/server/tests/unit/test_completion.py
@@ -42,10 +42,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
})
content = ""
for data in res:
assert "stop" in data and type(data["stop"]) == bool
if data["stop"]:
assert data["timings"]["prompt_n"] == n_prompt
assert data["timings"]["predicted_n"] == n_predicted
assert data["truncated"] == truncated
assert data["stop_type"] == "limit"
assert "generation_settings" in data
assert server.n_predict is not None
assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
assert data["generation_settings"]["seed"] == server.seed
assert match_regex(re_content, content)
else:
content += data["content"]
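The updated test pins down the non-OAI-compat stream contract: every chunk must carry a boolean "stop", intermediate chunks contribute "content", and only the final stop == True chunk carries timings, truncated, stop_type, and generation_settings. A hypothetical client-side sketch of the same contract follows; the prompt, host, and port are illustrative assumptions.

import json
import requests  # assumed HTTP client

with requests.post("http://localhost:8080/completion",
                   json={"prompt": "I believe the meaning of life is",
                         "n_predict": 8, "stream": True},
                   stream=True) as resp:
    content = ""
    for raw in resp.iter_lines():
        if not raw.startswith(b"data: "):
            continue
        data = json.loads(raw[len(b"data: "):])
        if data["stop"]:
            # final chunk: per this PR it carries timings, stop_type,
            # truncated, and generation_settings alongside stop == True
            print(content)
            print("stop_type:", data["stop_type"],
                  "predicted_n:", data["timings"]["predicted_n"])
        else:
            content += data["content"]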