Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions examples/server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,6 @@ endforeach()
add_executable(${TARGET} ${TARGET_SRCS})
install(TARGETS ${TARGET} RUNTIME)

# clean up generated files in pre-build step
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a note here, we should add a check in /scripts/xxd.cmake to see if the file need to be re-generated or not. I will do that in another PR.

Copy link
Member Author

@ggerganov ggerganov Dec 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. You mentioned that the /slots endpoint is also broken. I haven't looked at it yet. Maybe we can apply any additional fixes in this PR before merging? Feel free to push directly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup I fixed it in 01da1ed

I also fix a problem with cpp wrapper llama_get_chat_template because it returns null terminator in the final json:
Screenshot 2024-12-07 at 16 31 46

foreach(asset ${PUBLIC_ASSETS})
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND "${CMAKE_COMMAND}" -E remove -f "${output}"
)
endforeach()

target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})

if (LLAMA_SERVER_SSL)
Expand Down
75 changes: 39 additions & 36 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,28 +122,23 @@ struct slot_params {
struct common_params_sampling sampling;
struct common_params_speculative speculative;

// params only used in to_json()
int32_t n_ctx;
uint32_t seed_cur;
bool can_speculative;

// OAI-compat fields
bool verbose = false;
bool oaicompat = false;
bool oaicompat_chat = true;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;

json to_json() {
json to_json() const {
std::vector<std::string> samplers;
samplers.reserve(sampling.samplers.size());
for (const auto & sampler : sampling.samplers) {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}

return json {
{"n_ctx", n_ctx},
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
{"temperature", sampling.temp},
{"dynatemp_range", sampling.dynatemp_range},
{"dynatemp_exponent", sampling.dynatemp_exponent},
Expand Down Expand Up @@ -177,7 +172,6 @@ struct slot_params {
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
{"samplers", samplers},
{"speculative", can_speculative},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
{"speculative.p_min", speculative.p_min},
Expand Down Expand Up @@ -483,12 +477,6 @@ struct server_task_result_cmpl_partial : server_task_result {
return std::vector<json>({initial_ret, second_ret});
}
} else {
// Some idiosyncrasy in task processing logic makes several trailing calls
// with empty content, we ignore these at the calee site.
if (content.empty()) {
return std::vector<json>({json::object()});
}

Comment on lines -486 to -491
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes #10694

choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
Expand Down Expand Up @@ -722,6 +710,7 @@ struct server_slot {

llama_batch batch_spec;

llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;

common_speculative * spec = nullptr;
Expand Down Expand Up @@ -906,6 +895,27 @@ struct server_slot {
t_token_generation, n_decoded, t_gen, n_gen_second,
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
}

json to_json() const {
return json {
{"id", id},
{"id_task", id_task},
{"n_ctx", n_ctx},
{"speculative", can_speculate()},
{"is_processing", is_processing()},
{"params", params.to_json()},
{"prompt", common_detokenize(ctx, prompt_tokens)},
{"next_token",
{
{"has_next_token", has_next_token},
{"has_new_line", has_new_line},
{"n_remain", n_remaining},
{"n_decoded", n_decoded},
{"stopping_word", stopping_word},
}
},
};
}
};

struct server_metrics {
Expand Down Expand Up @@ -1338,6 +1348,7 @@ struct server_context {
server_slot slot;

slot.id = i;
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params_base.n_predict;

Expand Down Expand Up @@ -1370,8 +1381,7 @@ struct server_context {
slots.push_back(slot);
}

default_generation_settings_for_props = slots[0].params.to_json();
default_generation_settings_for_props["seed"] = -1;
default_generation_settings_for_props = slots[0].to_json();

// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
Expand Down Expand Up @@ -1848,17 +1858,18 @@ struct server_context {
queue_results.send(std::move(res));
}

void send_partial_response(server_slot & slot, completion_token_output tkn) {
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
auto res = std::make_unique<server_task_result_cmpl_partial>();
res->id = slot.id_task;
res->index = slot.index;
res->content = tkn.text_to_send;

res->id = slot.id_task;
res->index = slot.index;
res->content = tkn.text_to_send;

res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;

res->stop = slot.stop;
res->stop = slot.stop;

res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
Expand All @@ -1869,6 +1880,7 @@ struct server_context {
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);

const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());

Expand All @@ -1891,7 +1903,8 @@ struct server_context {
void send_final_response(server_slot & slot) {
if (slot.params.stream) {
// if in stream mode, send the last partial response
return send_partial_response(slot, {0, "", {}});
send_partial_response(slot, {0, "", {}});
return;
}

auto res = std::make_unique<server_task_result_cmpl_final>();
Expand Down Expand Up @@ -2012,6 +2025,7 @@ struct server_context {
std::vector<server_task> tasks;
auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());

server_task task;
task.id = queue_tasks.get_new_id();
task.inf_type = inf_type;
Expand Down Expand Up @@ -2205,18 +2219,7 @@ struct server_context {
int n_processing_slots = 0;

for (server_slot & slot : slots) {
json slot_data = slot.params.to_json();
slot_data["id"] = slot.id;
slot_data["id_task"] = slot.id_task;
slot_data["is_processing"] = slot.is_processing();
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
slot_data["next_token"] = {
{"has_next_token", slot.has_next_token},
{"has_new_line", slot.has_new_line},
{"n_remain", slot.n_remaining},
{"n_decoded", slot.n_decoded},
{"stopping_word", slot.stopping_word},
};
json slot_data = slot.to_json();

if (slot.is_processing()) {
n_processing_slots++;
Expand Down Expand Up @@ -3003,11 +3006,11 @@ int main(int argc, char ** argv) {
res.status = 200;
};

svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
std::string message;
try {
std::rethrow_exception(ep);
} catch (std::exception & e) {
} catch (const std::exception & e) {
message = e.what();
} catch (...) {
message = "Unknown Exception";
Expand Down
Loading