
Commit fb10521: add timings

1 parent 2c96bd2

2 files changed: +23, -6 lines

examples/server/server.cpp

Lines changed: 12 additions & 4 deletions
@@ -1313,6 +1313,7 @@ struct server_context {
             {"id_slot", slot.id},
             {"multimodal", false},
             {"index", slot.index},
+            {"timings", slot.get_formated_timings()},
         };

         if (slot.params.sampling.n_probs > 0) {

@@ -2274,12 +2275,17 @@ struct server_context {
             common_sampler_accept(slot.smpl, id, true);

             slot.n_decoded += 1;
+
+            const int64_t t_current = ggml_time_us();
+
             if (slot.n_decoded == 1) {
-                slot.t_start_generation = ggml_time_us();
+                slot.t_start_generation = t_current;
                 slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
                 metrics.on_prompt_eval(slot);
             }

+            slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
+
             completion_token_output result;
             result.tok = id;
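
For context, a minimal standalone sketch of the arithmetic introduced above, assuming (as the division by 1e3 suggests) that ggml_time_us() reports microseconds while the slot fields store milliseconds; the timestamp values are hypothetical:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical timestamps in microseconds, standing in for ggml_time_us().
        int64_t t_start_process_prompt = 0;
        int64_t t_start_generation     = 250000;   // first token decoded after 250 ms
        int64_t t_current              = 1250000;  // latest token decoded after 1.25 s
        int     n_decoded              = 40;       // tokens generated so far

        // Same conversions as in the hunk above: microseconds -> milliseconds.
        double t_prompt_processing = (t_start_generation - t_start_process_prompt) / 1e3;
        double t_token_generation  = (t_current - t_start_generation) / 1e3;

        printf("prompt processing: %.1f ms\n", t_prompt_processing);
        printf("token generation:  %.1f ms (%.1f tok/s)\n",
               t_token_generation, n_decoded / (t_token_generation / 1e3));
        return 0;
    }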

@@ -2995,23 +3001,25 @@ int main(int argc, char ** argv) {
         ctx_server.queue_tasks.post(tasks);

         bool stream = json_value(data, "stream", false);
+        bool timings = json_value(data, "timing_per_token", false);
+
         const auto task_ids = server_task::get_list_id(tasks);
         const auto completion_id = gen_chatcmplid();

         if (!stream) {
             ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
                 // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
+                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose, timings);
                 res_ok(res, result_oai);
             }, [&](const json & error_data) {
                 res_error(res, error_data);
             });

             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, timings](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
+                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, timings);
                     for (auto & event_data : result_array) {
                         if (event_data.empty()) {
                             continue; // skip the stop token
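
For reference, a hedged client-side sketch of how the new flag could be exercised: the request body carries "timing_per_token": true next to "stream". It is built here with nlohmann::json, which the server code itself uses; everything other than those two keys (the message payload, the endpoint) is a placeholder, not something this diff defines:

    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // "stream" already existed; "timing_per_token" is the key this commit reads.
        json body = {
            {"messages", json::array({ { {"role", "user"}, {"content", "Hello"} } })},
            {"stream", true},
            {"timing_per_token", true},
        };
        std::cout << body.dump(2) << std::endl;   // body to POST to the chat endpoint
        return 0;
    }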

examples/server/utils.hpp

Lines changed: 11 additions & 2 deletions
@@ -604,7 +604,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }

-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool timings = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);

@@ -650,11 +650,15 @@ static json format_final_response_oaicompat(const json & request, const json & r
         res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
     }

+    if (timings) {
+        res.push_back({"timings", json_value(result, "timings", json::object())});
+    }
+
     return res;
 }

 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
+static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool timings = false) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }

@@ -740,6 +744,11 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
             {"model", modelname},
             {"object", "chat.completion.chunk"}
         };
+
+        if (timings) {
+            ret.push_back({"timings", json_value(result, "timings", json::object())});
+        }
+
         if (!finish_reason.empty()) {
             int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
             int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
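
With the flag set, both the final response and each streamed chunk should now carry a top-level "timings" object; its keys come from slot.get_formated_timings(), which this diff does not show, so the sketch below (again nlohmann::json, with hypothetical chunk contents) treats it as opaque:

    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // Hypothetical chunk as a client might receive it; the contents of
        // "timings" depend on slot.get_formated_timings() and are left empty here.
        json chunk = json::parse(R"({"object":"chat.completion.chunk","timings":{}})");
        if (chunk.contains("timings")) {
            std::cout << "per-token timings: " << chunk["timings"].dump() << std::endl;
        }
        return 0;
    }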
