Commit 0d6485f

wip [no ci]
1 parent 1011a51 commit 0d6485f

3 files changed: +51 lines, -48 lines

examples/server/server.cpp

Lines changed: 18 additions & 8 deletions
@@ -494,15 +494,17 @@ struct server_response {
     }

     // Send a new result to a waiting id_task
-    void send(server_task_result & result) {
+    template<typename T>
+    void send(T & result) {
+        static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
         SRV_DBG("sending result for task id = %d\n", result.id);

         std::unique_lock<std::mutex> lock(mutex_results);
         for (const auto & id_task : waiting_task_ids) {
             if (result.id == id_task) {
                 SRV_DBG("task id = %d pushed to result queue\n", result.id);

-                queue_results.push_back(std::make_unique<server_task_result>(result));
+                queue_results.push_back(std::make_unique<T>(std::move(result)));
                 condition_results.notify_all();
                 return;
             }
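
Note (illustrative, not part of the commit): pushing a derived result through std::make_unique<server_task_result>(result) would slice off the derived fields, whereas deducing T and using make_unique<T> preserves them; the static_assert rejects unrelated types at compile time. A minimal self-contained sketch with stand-in types (not the server's actual classes):

    #include <cstdio>
    #include <memory>
    #include <type_traits>
    #include <vector>

    // Stand-ins for server_task_result and one of its derived result types.
    struct result_base { int id = 0; virtual ~result_base() = default; };
    struct result_partial : result_base { int n_decoded = 0; };

    static std::vector<std::unique_ptr<result_base>> queue_results;

    // Deducing T keeps the concrete type; the static_assert rejects unrelated types.
    template <typename T>
    static void send(T & result) {
        static_assert(std::is_base_of<result_base, T>::value,
                      "T must be derived from result_base");
        // make_unique<result_base>(result) would slice; make_unique<T> does not.
        queue_results.push_back(std::make_unique<T>(std::move(result)));
    }

    int main() {
        result_partial r;
        r.id        = 7;
        r.n_decoded = 3;
        send(r);
        std::printf("queued %zu result(s)\n", queue_results.size());
    }
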
@@ -1166,8 +1168,10 @@ struct server_context {

     void send_partial_response(server_slot & slot, completion_token_output tkn) {
         server_task_result_cmpl_partial res;
-        res.id = slot.id_task;
-        res.content = tkn.text_to_send;
+        res.id = slot.id_task;
+        res.n_decoded = slot.n_decoded;
+        res.n_prompt_tokens = slot.n_prompt_tokens;
+        res.content = tkn.text_to_send;

         if (slot.params.sampling.n_probs > 0) {
             const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
@@ -1189,7 +1193,11 @@ struct server_context {
         queue_results.send(res);
     }

-    void send_final_response(const server_slot & slot) {
+    void send_final_response(server_slot & slot) {
+        if (slot.params.stream) {
+            return send_partial_response(slot, {0, "", {}});
+        }
+
         server_task_result_cmpl_final res;
         res.id = slot.id_task;
         res.id_slot = slot.id;
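
Note (illustrative, not part of the commit): in streaming mode the final response above is delivered as one more partial chunk carrying an empty token, so the finish information travels down the same path as the regular deltas. A compressed sketch of that control flow, with placeholder types rather than the server's own:

    #include <cstdio>
    #include <string>
    #include <vector>

    struct token_out { int tok; std::string text; std::vector<float> probs; };
    struct slot_like { bool stream = false; };

    static void send_partial(slot_like &, const token_out & tkn) {
        std::printf("partial chunk: \"%s\"\n", tkn.text.c_str());
    }

    static void send_final(slot_like & slot) {
        if (slot.stream) {
            // streaming: emit one last, empty chunk instead of a separate final result
            send_partial(slot, {0, "", {}});
            return;
        }
        std::printf("dedicated final result (non-streaming path)\n");
    }

    int main() {
        slot_like s;
        s.stream = true;
        send_final(s);  // prints an empty partial chunk
    }
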
@@ -1380,6 +1388,7 @@ struct server_context {
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<T>&)> & result_handler,
             const std::function<void(json)> & error_handler) {
+        static_assert(std::is_base_of<server_task_result, T>::value, "T must be derived from server_task_result");
         std::vector<T> results(id_tasks.size());
         for (size_t i = 0; i < id_tasks.size(); i++) {
             task_result_ptr result_raw = queue_results.recv(id_tasks);
@@ -2815,17 +2824,18 @@ int main(int argc, char ** argv) {
         if (!stream) {
             ctx_server.receive_multi_results<server_task_result_cmpl_final>(task_ids, [&](std::vector<server_task_result_cmpl_final> & results) {
                 // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].to_json(), completion_id, /*.streaming =*/ false, verbose);
+                json result_oai = format_final_response_oaicompat(data, results[0], completion_id, /*.streaming =*/ false, verbose);
                 res_ok(res, result_oai);
             }, [&](const json & error_data) {
                 res_error(res, error_data);
             });

             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+            std::string model_name = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, model_name](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_cmpl_partial & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(result.to_json(), completion_id);
+                    std::vector<json> result_array = format_partial_response_oaicompat(model_name, result, completion_id);
                     for (auto & event_data : result_array) {
                         if (event_data.empty()) {
                             continue; // skip the stop token
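
Note (illustrative, not part of the commit): the streaming branch above ultimately emits OpenAI-style "chat.completion.chunk" objects over SSE. A hedged sketch of what a single delta chunk might look like when built with nlohmann::json — the field values are made up; only the overall shape follows the OpenAI chat-completions chunk format:

    #include <iostream>
    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    int main() {
        // one streamed delta, as produced for a single decoded token
        json chunk = {
            {"id",      "chatcmpl-abc123"},            // completion_id
            {"object",  "chat.completion.chunk"},
            {"created", 1700000000},
            {"model",   "my-model"},                   // model_name captured by the provider
            {"choices", json::array({ json{
                {"index",         0},
                {"delta",         json{{"content", "Hello"}}},
                {"finish_reason", nullptr}
            }})}
        };
        std::cout << "data: " << chunk.dump() << "\n\n";  // SSE framing
    }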

examples/server/server.hpp

Lines changed: 2 additions & 0 deletions
@@ -281,6 +281,8 @@ struct server_task_result_cmpl_partial : server_task_result {
     server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
     int index = 0;
     std::string content;
+    int32_t n_decoded;
+    int32_t n_prompt_tokens;
     stop_type stop = STOP_TYPE_NONE;
     std::vector<completion_token_output> probs_output;
     result_timings timings;

examples/server/utils.hpp

Lines changed: 31 additions & 40 deletions
@@ -583,15 +583,14 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }

-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
-    bool stopped_word = result.count("stopped_word") != 0;
-    bool stopped_eos = json_value(result, "stopped_eos", false);
-    int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-    int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
-    std::string content = json_value(result, "content", std::string(""));
-
+static json format_final_response_oaicompat(
+        const json & request,
+        server_task_result_cmpl_final & result,
+        const std::string & completion_id,
+        bool streaming = false,
+        bool verbose = false) {
     std::string finish_reason = "length";
-    if (stopped_word || stopped_eos) {
+    if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
         finish_reason = "stop";
     }
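
Note (illustrative, not part of the commit): the stop-type to finish_reason mapping used in this formatter could be factored into one small helper. A sketch with a stand-in enum, mirroring the non-streaming case where an unfinished generation defaults to "length":

    #include <cstdio>
    #include <string>

    // stand-in for the server's stop_type
    enum stop_type_like { STOP_NONE, STOP_EOS, STOP_WORD, STOP_LIMIT };

    // OpenAI-compatible finish_reason: "stop" for EOS / stop word, "length" otherwise
    static std::string to_finish_reason(stop_type_like stop) {
        if (stop == STOP_EOS || stop == STOP_WORD) {
            return "stop";
        }
        return "length";
    }

    int main() {
        std::printf("%s\n", to_finish_reason(STOP_LIMIT).c_str());  // prints "length"
    }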

@@ -601,7 +600,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
                 {"delta", json::object()}}})
             : json::array({json{{"finish_reason", finish_reason},
                             {"index", 0},
-                            {"message", json{{"content", content},
+                            {"message", json{{"content", result.content},
                                              {"role", "assistant"}}}}});

     std::time_t t = std::time(0);
@@ -613,48 +612,42 @@ static json format_final_response_oaicompat(const json & request, const json & r
             json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
         {"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens", num_prompt_tokens},
-            {"total_tokens", num_tokens_predicted + num_prompt_tokens}
+            {"completion_tokens", result.n_decoded},
+            {"prompt_tokens", result.n_prompt_tokens},
+            {"total_tokens", result.n_decoded + result.n_prompt_tokens}
         }},
         {"id", completion_id}
     };

     // extra fields for debugging purposes
     if (verbose) {
-        res["__verbose"] = result;
+        res["__verbose"] = result.to_json();
     }

-    if (result.contains("completion_probabilities")) {
-        res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
-    }
+    // TODO: fix this
+    // if (result.contains("completion_probabilities")) {
+    //     res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
+    // }

-    if (result.contains("timings")) {
-        res.push_back({"timings", json_value(result, "timings", json::object())});
+    if (result.timings.prompt_n >= 0) {
+        res.push_back({"timings", result.timings.to_json()});
     }

     return res;
 }

 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
-    if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
-        return std::vector<json>({result});
-    }
-
-    bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
-    std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-
-    bool stopped_word = json_value(result, "stopped_word", false);
-    bool stopped_eos = json_value(result, "stopped_eos", false);
-    bool stopped_limit = json_value(result, "stopped_limit", false);
-    std::string content = json_value(result, "content", std::string(""));
+static std::vector<json> format_partial_response_oaicompat(
+        std::string modelname,
+        server_task_result_cmpl_partial & result,
+        const std::string & completion_id) {
+    bool first = result.n_decoded == 0;
+    std::string content = result.content;

     std::string finish_reason;
-    if (stopped_word || stopped_eos) {
+    if (result.stop == STOP_TYPE_WORD || result.stop == STOP_TYPE_EOS) {
         finish_reason = "stop";
-    }
-    if (stopped_limit) {
+    } else if (result.stop == STOP_TYPE_LIMIT) {
         finish_reason = "length";
     }
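
Note (illustrative, not part of the commit): the usage accounting above now reads struct fields instead of json_value lookups. A minimal self-contained sketch with nlohmann::json; the struct is an assumed stand-in mirroring the fields the formatter reads (the real type is server_task_result_cmpl_final):

    #include <cstdint>
    #include <iostream>
    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    // assumed shape, for illustration only
    struct final_result_like {
        int32_t n_decoded       = 0;  // tokens generated
        int32_t n_prompt_tokens = 0;  // prompt tokens evaluated
    };

    static json make_usage(const final_result_like & r) {
        return json{
            {"completion_tokens", r.n_decoded},
            {"prompt_tokens",     r.n_prompt_tokens},
            {"total_tokens",      r.n_decoded + r.n_prompt_tokens}
        };
    }

    int main() {
        final_result_like r;
        r.n_decoded       = 42;
        r.n_prompt_tokens = 8;
        std::cout << make_usage(r).dump(2) << "\n";  // total_tokens: 50
    }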

@@ -724,17 +717,15 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
         {"object", "chat.completion.chunk"}
     };

-    if (result.contains("timings")) {
-        ret.push_back({"timings", json_value(result, "timings", json::object())});
+    if (result.timings.prompt_n >= 0) {
+        ret.push_back({"timings", result.timings.to_json()});
     }

     if (!finish_reason.empty()) {
-        int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
-        int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
         ret.push_back({"usage", json {
-            {"completion_tokens", num_tokens_predicted},
-            {"prompt_tokens", num_prompt_tokens},
-            {"total_tokens", num_tokens_predicted + num_prompt_tokens}
+            {"completion_tokens", result.n_decoded},
+            {"prompt_tokens", result.n_prompt_tokens},
+            {"total_tokens", result.n_decoded + result.n_prompt_tokens}
         }});
     }
