Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 72 additions & 58 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ enum server_task_type {
SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE,
SERVER_TASK_TYPE_SET_LORA,
SERVER_TASK_TYPE_NONE,
};

enum oaicompat_type {
Expand Down Expand Up @@ -480,6 +481,7 @@ struct result_timings {
struct server_task_result {
int id = -1;
int id_slot = -1;
// Task type stored on the result. Default-initialized to the NONE sentinel so
// it never holds an indeterminate value (matches the in-class initializers of
// the neighbouring id / id_slot members).
server_task_type type = SERVER_TASK_TYPE_NONE;
virtual bool is_error() {
// only used by server_task_result_error
return false;
Expand All @@ -491,6 +493,7 @@ struct server_task_result {
virtual int get_index() {
return -1;
}
virtual server_task_type get_server_task_type() = 0;
virtual json to_json() = 0;
virtual ~server_task_result() = default;
};
Expand Down Expand Up @@ -794,6 +797,10 @@ struct server_task_result_cmpl_final : server_task_result {

return ret;
}

// Tags this result with SERVER_TASK_TYPE_COMPLETION.
server_task_type get_server_task_type() override { return SERVER_TASK_TYPE_COMPLETION; }
};

struct server_task_result_cmpl_partial : server_task_result {
Expand Down Expand Up @@ -962,6 +969,10 @@ struct server_task_result_cmpl_partial : server_task_result {

return std::vector<json>({ret});
}

// Reports SERVER_TASK_TYPE_NONE: this result does not name a specific task type.
server_task_type get_server_task_type() override { return SERVER_TASK_TYPE_NONE; }
};

struct server_task_result_embd : server_task_result {
Expand Down Expand Up @@ -997,6 +1008,10 @@ struct server_task_result_embd : server_task_result {
{"tokens_evaluated", n_tokens},
};
}

// Tags this result with SERVER_TASK_TYPE_EMBEDDING.
server_task_type get_server_task_type() override { return SERVER_TASK_TYPE_EMBEDDING; }
};

struct server_task_result_rerank : server_task_result {
Expand All @@ -1016,6 +1031,10 @@ struct server_task_result_rerank : server_task_result {
{"tokens_evaluated", n_tokens},
};
}

// Tags this result with SERVER_TASK_TYPE_RERANK.
server_task_type get_server_task_type() override { return SERVER_TASK_TYPE_RERANK; }
};

// this function may be used outside of server_task_result_error
Expand Down Expand Up @@ -1071,6 +1090,10 @@ struct server_task_result_error : server_task_result {
// Serializes the error by passing err_msg and err_type to format_error_response().
json to_json() override {
    return format_error_response(err_msg, err_type);
}

// Reports SERVER_TASK_TYPE_NONE: an error result does not name a specific task type.
server_task_type get_server_task_type() override { return SERVER_TASK_TYPE_NONE; }
};

struct server_task_result_metrics : server_task_result {
Expand Down Expand Up @@ -1127,6 +1150,10 @@ struct server_task_result_metrics : server_task_result {
{ "slots", slots_data },
};
}

// Tags this result with SERVER_TASK_TYPE_METRICS.
server_task_type get_server_task_type() override { return SERVER_TASK_TYPE_METRICS; }
};

struct server_task_result_slot_save_load : server_task_result {
Expand Down Expand Up @@ -1160,6 +1187,10 @@ struct server_task_result_slot_save_load : server_task_result {
};
}
}

// Always reports SERVER_TASK_TYPE_SLOT_SAVE. Code that validates the result of
// a SERVER_TASK_TYPE_SLOT_RESTORE request therefore has to compare against
// SLOT_SAVE as well.
server_task_type get_server_task_type() override { return SERVER_TASK_TYPE_SLOT_SAVE; }
};

struct server_task_result_slot_erase : server_task_result {
Expand All @@ -1171,12 +1202,20 @@ struct server_task_result_slot_erase : server_task_result {
{ "n_erased", n_erased },
};
}

// Tags this result with SERVER_TASK_TYPE_SLOT_ERASE.
server_task_type get_server_task_type() override { return SERVER_TASK_TYPE_SLOT_ERASE; }
};

// Result whose JSON form is only a success flag.
struct server_task_result_apply_lora : server_task_result {
    // Always serializes to {"success": true}.
    virtual json to_json() override {
        return json {{ "success", true }};
    }

    // Report the matching task type instead of the NONE sentinel, in line with
    // the other result structs that return their own SERVER_TASK_TYPE_* value.
    // NOTE(review): assumes this result is produced by SERVER_TASK_TYPE_SET_LORA
    // tasks -- confirm against the task dispatcher.
    server_task_type get_server_task_type() override {
        return SERVER_TASK_TYPE_SET_LORA;
    }
};

struct server_slot {
Expand Down Expand Up @@ -2751,6 +2790,11 @@ struct server_context {
res->id = task.id;
queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_NONE:
{
// SERVER_TASK_TYPE_NONE is not a processable task type: reaching this case
// is treated as a programming error, so the assertion below fails unconditionally.
GGML_ASSERT(false && "Invalid task.type (SERVER_TASK_TYPE_NONE)\n");
}
}
}

Expand Down Expand Up @@ -3665,49 +3709,23 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};

const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::string filepath = params.slot_save_path + filename;

server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
const auto handle_slot_impl = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req,
httplib::Response & res, int id_slot, server_task_type type) {
server_task task(type);
task.id = ctx_server.queue_tasks.get_new_id();
task.slot_action.slot_id = id_slot;
task.slot_action.filename = filename;
task.slot_action.filepath = filepath;

ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);

server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);

if (result->is_error()) {
res_error(res, result->to_json());
return;
}

res_ok(res, result->to_json());
};

const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return;
if (type != SERVER_TASK_TYPE_SLOT_ERASE) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
return;
}
std::string filepath = params.slot_save_path + filename;
task.slot_action.filename = filename;
task.slot_action.filepath = filepath;
}
std::string filepath = params.slot_save_path + filename;

server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
task.id = ctx_server.queue_tasks.get_new_id();
task.slot_action.slot_id = id_slot;
task.slot_action.filename = filename;
task.slot_action.filepath = filepath;

ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
Expand All @@ -3719,32 +3737,28 @@ int main(int argc, char ** argv) {
res_error(res, result->to_json());
return;
}

GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
GGML_ASSERT(result.get() != nullptr);
if (type == SERVER_TASK_TYPE_SLOT_RESTORE) {
GGML_ASSERT(result.get()->get_server_task_type() == SERVER_TASK_TYPE_SLOT_SAVE);
} else {
GGML_ASSERT(result.get()->get_server_task_type() == type);
}
res_ok(res, result->to_json());
};

const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
task.id = ctx_server.queue_tasks.get_new_id();
task.slot_action.slot_id = id_slot;

ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);

server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
// Slot-save entry point: forwards the request to handle_slot_impl tagged as
// SERVER_TASK_TYPE_SLOT_SAVE.
const auto handle_slots_save = [&handle_slot_impl](
        const httplib::Request & req,
        httplib::Response & res,
        int id_slot) {
    handle_slot_impl(req, res, id_slot, SERVER_TASK_TYPE_SLOT_SAVE);
};

if (result->is_error()) {
res_error(res, result->to_json());
return;
}
// Slot-restore entry point: forwards the request to handle_slot_impl tagged as
// SERVER_TASK_TYPE_SLOT_RESTORE.
const auto handle_slots_restore = [&handle_slot_impl](
        const httplib::Request & req,
        httplib::Response & res,
        int id_slot) {
    handle_slot_impl(req, res, id_slot, SERVER_TASK_TYPE_SLOT_RESTORE);
};

GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
res_ok(res, result->to_json());
// Slot-erase entry point: forwards the request to handle_slot_impl tagged as
// SERVER_TASK_TYPE_SLOT_ERASE.
const auto handle_slots_erase = [&handle_slot_impl](
        const httplib::Request & req,
        httplib::Response & res,
        int id_slot) {
    handle_slot_impl(req, res, id_slot, SERVER_TASK_TYPE_SLOT_ERASE);
};

const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
if (params.slot_save_path.empty()) {
res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
return;
Expand Down
Loading