Commit 21f8b73

fix code
1 parent c47c41c commit 21f8b73

2 files changed: 16 additions, 10 deletions


examples/server/server.cpp

Lines changed: 12 additions & 6 deletions
@@ -177,6 +177,8 @@ struct server_slot {
     bool stopped_word   = false;
     bool stopped_limit  = false;
 
+    bool timing_per_token = false;
+
     bool oaicompat = false;
 
     std::string oaicompat_model;
@@ -882,6 +884,8 @@ struct server_context {
             slot.oaicompat_model = "";
         }
 
+        slot.timing_per_token = json_value(data, "timing_per_token", false);
+
         slot.params.stream       = json_value(data, "stream", false);
         slot.params.cache_prompt = json_value(data, "cache_prompt", true);
         slot.params.n_predict    = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
@@ -1269,7 +1273,6 @@ struct server_context {
             {"n_keep",               slot.params.n_keep},
             {"n_discard",            slot.params.n_discard},
             {"ignore_eos",           slot.params.sampling.ignore_eos},
-            {"timing_per_token",     slot.params.sampling.timing_per_token},
             {"stream",               slot.params.stream},
           //{"logit_bias",           slot.params.sampling.logit_bias},
             {"n_probs",              slot.params.sampling.n_probs},
@@ -1280,6 +1283,7 @@ struct server_context {
             {"speculative.n_max",    slot.params.speculative.n_max},
             {"speculative.n_min",    slot.params.speculative.n_min},
             {"speculative.p_min",    slot.params.speculative.p_min},
+            {"timing_per_token",     slot.timing_per_token},
         };
     }
 
@@ -1314,7 +1318,6 @@ struct server_context {
             {"id_slot",    slot.id},
             {"multimodal", false},
             {"index",      slot.index},
-            {"timings",    slot.get_formated_timings()},
         };
 
         if (slot.params.sampling.n_probs > 0) {
@@ -1338,6 +1341,10 @@ struct server_context {
             res.data["model"] = slot.oaicompat_model;
         }
 
+        if (slot.timing_per_token) {
+            res.data["timings"] = slot.get_formated_timings();
+        }
+
         queue_results.send(res);
     }
 
@@ -3002,25 +3009,24 @@ int main(int argc, char ** argv) {
        ctx_server.queue_tasks.post(tasks);
 
        bool stream = json_value(data, "stream", false);
-       bool timings = json_value(data, "timing_per_token", false);
 
        const auto task_ids = server_task::get_list_id(tasks);
        const auto completion_id = gen_chatcmplid();
 
        if (!stream) {
            ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
                // multitask is never support in chat completion, there is only one result
-               json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose, timings);
+               json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
                res_ok(res, result_oai);
            }, [&](const json & error_data) {
                res_error(res, error_data);
            });
 
            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
        } else {
-           const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, timings](size_t, httplib::DataSink & sink) {
+           const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
               ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                  std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, timings);
+                  std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
                   for (auto & event_data : result_array) {
                       if (event_data.empty()) {
                           continue; // skip the stop token
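
Net effect of the server.cpp changes: the per-token timing switch now lives on the slot (slot.timing_per_token), is read once from the request body, and a "timings" object is attached to a result only when that flag is set. Below is a minimal client-side sketch of opting in, assuming nlohmann::json and an ordinary OpenAI-compatible chat request; the model name and message content are placeholders, not part of this commit.

// Sketch only: build a chat-completion request body that opts into per-token
// timings. "stream" and "timing_per_token" are the fields this commit reads via
// json_value(); everything else here is an illustrative placeholder.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json request;
    request["model"]    = "placeholder-model";   // assumed field, echoed back via slot.oaicompat_model
    request["messages"] = nlohmann::json::array({
        {{"role", "user"}, {"content", "Hello"}}
    });
    request["stream"]           = true;   // with streaming, each chunk can carry its own "timings"
    request["timing_per_token"] = true;   // the new per-slot flag; defaults to false

    std::cout << request.dump(2) << std::endl;

    // When timing_per_token is true, the server attaches res.data["timings"]
    // (slot.get_formated_timings()) to each result, and the formatters in
    // utils.hpp forward it whenever result.contains("timings").
    return 0;
}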

examples/server/utils.hpp

Lines changed: 4 additions & 4 deletions
@@ -604,7 +604,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool timings = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
     bool stopped_word        = result.count("stopped_word") != 0;
     bool stopped_eos         = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -650,15 +650,15 @@ static json format_final_response_oaicompat(const json & request, const json & r
         res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
     }
 
-    if (timings) {
+    if (result.contains("timings")) {
         res.push_back({"timings", json_value(result, "timings", json::object())});
     }
 
     return res;
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool timings = false) {
+static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
@@ -745,7 +745,7 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
         {"object", "chat.completion.chunk"}
     };
 
-    if (timings) {
+    if (result.contains("timings")) {
         ret.push_back({"timings", json_value(result, "timings", json::object())});
     }
 
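
A note on the design choice in utils.hpp: instead of threading a timings bool through the formatting helpers, the formatters now key off whether the result JSON already contains a "timings" object, so the decision is made once where the result is produced. Below is a small self-contained sketch of that pattern, assuming nlohmann::json; format_chunk is a hypothetical stand-in, not the server's actual helper.

// Sketch of the "forward the field only if present" pattern used by the
// OAI-compat formatters after this commit. format_chunk is an illustrative
// stand-in for format_partial_response_oaicompat.
#include <nlohmann/json.hpp>
#include <iostream>

static nlohmann::json format_chunk(const nlohmann::json & result) {
    nlohmann::json ret = {{"object", "chat.completion.chunk"}};
    // No separate flag: the presence of "timings" in the upstream result decides.
    if (result.contains("timings")) {
        ret["timings"] = result.at("timings");
    }
    return ret;
}

int main() {
    nlohmann::json with_timings = {
        {"timings", {{"predicted_per_second", 42.0}}}   // illustrative timing key
    };
    nlohmann::json without_timings = nlohmann::json::object();

    std::cout << format_chunk(with_timings).dump()    << std::endl;  // includes "timings"
    std::cout << format_chunk(without_timings).dump() << std::endl;  // omits it
    return 0;
}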