1 change: 1 addition & 0 deletions common/common.h
@@ -133,6 +133,7 @@ struct common_params_sampling {
bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
bool timing_per_token = false; // include timing information per token in server responses

std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

2 changes: 2 additions & 0 deletions examples/server/README.md
@@ -416,6 +416,8 @@ node index.js

`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.

`timing_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`

**Response format**

- Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion.
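For context on the `timing_per_token` option documented above, here is a rough sketch of the shape the added `timings` object takes, built with the nlohmann::json type the server already uses; the exact key names are an assumption based on the server's existing timing output, not something this diff guarantees:

```cpp
// Illustrative only: approximate shape of the "timings" object attached to a
// response when timing_per_token is true. Key names are an assumption based on
// the server's existing /completion timing output.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json timings = {
        {"prompt_n",             128},   // prompt tokens evaluated
        {"prompt_ms",            250.0}, // prompt processing time
        {"predicted_n",          32},    // tokens generated so far
        {"predicted_ms",         800.0}, // generation time
        {"predicted_per_second", 40.0}   // 1e3 * predicted_n / predicted_ms
    };
    std::cout << timings.dump(2) << std::endl;
    return 0;
}
```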
17 changes: 13 additions & 4 deletions examples/server/server.cpp
@@ -1269,6 +1269,7 @@ struct server_context {
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
{"ignore_eos", slot.params.sampling.ignore_eos},
{"timing_per_token", slot.params.sampling.timing_per_token},
{"stream", slot.params.stream},
//{"logit_bias", slot.params.sampling.logit_bias},
{"n_probs", slot.params.sampling.n_probs},
@@ -1313,6 +1314,7 @@ struct server_context {
{"id_slot", slot.id},
{"multimodal", false},
{"index", slot.index},
{"timings", slot.get_formated_timings()},
};

if (slot.params.sampling.n_probs > 0) {
@@ -2274,12 +2276,17 @@ struct server_context {
common_sampler_accept(slot.smpl, id, true);

slot.n_decoded += 1;

const int64_t t_current = ggml_time_us();

if (slot.n_decoded == 1) {
slot.t_start_generation = ggml_time_us();
slot.t_start_generation = t_current;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}

slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

completion_token_output result;
result.tok = id;

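The hunk above captures `t_current` once and refreshes `slot.t_token_generation` after every decoded token instead of only at the end of the request; `slot.get_formated_timings()` then reports these millisecond totals per response. A minimal standalone sketch of the derived per-token and tokens-per-second arithmetic, with example values standing in for the slot fields (nothing here is taken from the patch itself):

```cpp
// Standalone sketch, not part of the patch: how millisecond totals recorded in
// the slot translate into per-token and tokens-per-second figures.
#include <cstdio>

int main() {
    // Example values; in the server these come from the slot fields shown above.
    const int    n_prompt            = 128;    // tokens evaluated during prompt processing
    const double t_prompt_processing = 250.0;  // ms, (t_start_generation - t_start_process_prompt) / 1e3
    const int    n_decoded           = 32;     // tokens generated so far
    const double t_token_generation  = 800.0;  // ms, (t_current - t_start_generation) / 1e3

    printf("prompt: %8.2f ms per token, %8.2f tokens per second\n",
           t_prompt_processing / n_prompt, 1e3 * n_prompt / t_prompt_processing);
    printf("eval:   %8.2f ms per token, %8.2f tokens per second\n",
           t_token_generation / n_decoded, 1e3 * n_decoded / t_token_generation);
    return 0;
}
```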
@@ -2995,23 +3002,25 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.post(tasks);

bool stream = json_value(data, "stream", false);
bool timings = json_value(data, "timing_per_token", false);

const auto task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();

if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
// multitask is never supported in chat completion, so there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose, timings);
res_ok(res, result_oai);
}, [&](const json & error_data) {
res_error(res, error_data);
});

ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, timings](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, timings);
for (auto & event_data : result_array) {
if (event_data.empty()) {
continue; // skip the stop token
13 changes: 11 additions & 2 deletions examples/server/utils.hpp
@@ -604,7 +604,7 @@ static json oaicompat_completion_params_parse(
return llama_params;
}

static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool timings = false) {
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -650,11 +650,15 @@ static json format_final_response_oaicompat(const json & request, const json & r
res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
}

if (timings) {
res.push_back({"timings", json_value(result, "timings", json::object())});
}

return res;
}

// the return value is a vector, as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool timings = false) {
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({result});
}
@@ -740,6 +744,11 @@ static std::vector<json> format_partial_response_oaicompat(const json & result,
{"model", modelname},
{"object", "chat.completion.chunk"}
};

if (timings) {
ret.push_back({"timings", json_value(result, "timings", json::object())});
}

if (!finish_reason.empty()) {
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
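A detail worth noting in the `push_back({"timings", ...})` additions above: on an nlohmann::json object, `push_back` with a two-element `{key, value}` initializer list inserts a key/value pair rather than appending to an array. A self-contained sketch of that behaviour, with placeholder values:

```cpp
// Self-contained sketch of the push_back pattern used above: on a JSON object,
// push_back with a {key, value} initializer list inserts a key/value pair.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json chunk = {
        {"object", "chat.completion.chunk"},
        {"model",  "placeholder-model"}   // placeholder, not from the PR
    };

    const bool timings = true; // mirrors the per-request timing_per_token flag
    if (timings) {
        // Same call shape as in format_partial_response_oaicompat above.
        chunk.push_back({"timings", nlohmann::json{{"predicted_n", 1}, {"predicted_ms", 12.5}}});
    }

    std::cout << chunk.dump(2) << std::endl; // "timings" shows up as a top-level key
    return 0;
}
```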