diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1ebcb50854d31..329492e29d6f7 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -68,6 +68,7 @@ enum server_task_type { SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE, SERVER_TASK_TYPE_SET_LORA, + SERVER_TASK_TYPE_NONE, }; enum oaicompat_type { @@ -480,6 +481,7 @@ struct result_timings { struct server_task_result { int id = -1; int id_slot = -1; + server_task_type type; virtual bool is_error() { // only used by server_task_result_error return false; @@ -491,6 +493,7 @@ struct server_task_result { virtual int get_index() { return -1; } + virtual server_task_type get_server_task_type() = 0; virtual json to_json() = 0; virtual ~server_task_result() = default; }; @@ -794,6 +797,10 @@ struct server_task_result_cmpl_final : server_task_result { return ret; } + + server_task_type get_server_task_type() override { + return SERVER_TASK_TYPE_COMPLETION; + } }; struct server_task_result_cmpl_partial : server_task_result { @@ -962,6 +969,10 @@ struct server_task_result_cmpl_partial : server_task_result { return std::vector({ret}); } + + server_task_type get_server_task_type() override { + return SERVER_TASK_TYPE_NONE; + } }; struct server_task_result_embd : server_task_result { @@ -997,6 +1008,10 @@ struct server_task_result_embd : server_task_result { {"tokens_evaluated", n_tokens}, }; } + + server_task_type get_server_task_type() override { + return SERVER_TASK_TYPE_EMBEDDING; + } }; struct server_task_result_rerank : server_task_result { @@ -1016,6 +1031,10 @@ struct server_task_result_rerank : server_task_result { {"tokens_evaluated", n_tokens}, }; } + + server_task_type get_server_task_type() override { + return SERVER_TASK_TYPE_RERANK; + } }; // this function maybe used outside of server_task_result_error @@ -1071,6 +1090,10 @@ struct server_task_result_error : server_task_result { virtual json to_json() override { return format_error_response(err_msg, err_type); } + + server_task_type get_server_task_type() override { + return SERVER_TASK_TYPE_NONE; + } }; struct server_task_result_metrics : server_task_result { @@ -1127,6 +1150,10 @@ struct server_task_result_metrics : server_task_result { { "slots", slots_data }, }; } + + server_task_type get_server_task_type() override { + return SERVER_TASK_TYPE_METRICS; + } }; struct server_task_result_slot_save_load : server_task_result { @@ -1160,6 +1187,10 @@ struct server_task_result_slot_save_load : server_task_result { }; } } + + server_task_type get_server_task_type() override { + return SERVER_TASK_TYPE_SLOT_SAVE; + } }; struct server_task_result_slot_erase : server_task_result { @@ -1171,12 +1202,20 @@ struct server_task_result_slot_erase : server_task_result { { "n_erased", n_erased }, }; } + + server_task_type get_server_task_type() override { + return SERVER_TASK_TYPE_SLOT_ERASE; + } }; struct server_task_result_apply_lora : server_task_result { virtual json to_json() override { return json {{ "success", true }}; } + + server_task_type get_server_task_type() override { + return SERVER_TASK_TYPE_NONE; + } }; struct server_slot { @@ -2751,6 +2790,11 @@ struct server_context { res->id = task.id; queue_results.send(std::move(res)); } break; + case SERVER_TASK_TYPE_NONE: + { + // do nothing + GGML_ASSERT(false && "Invalid task.type (SERVER_TASK_TYPE_NONE)\n"); + } } } @@ -3665,49 +3709,23 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; - const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; - } - std::string filepath = params.slot_save_path + filename; - - server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + const auto handle_slot_impl = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, + httplib::Response & res, int id_slot, server_task_type type) { + server_task task(type); task.id = ctx_server.queue_tasks.get_new_id(); task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(task); - - server_task_result_ptr result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - if (result->is_error()) { - res_error(res, result->to_json()); - return; - } - - res_ok(res, result->to_json()); - }; - - const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; + if (type != SERVER_TASK_TYPE_SLOT_ERASE) { + json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return; + } + std::string filepath = params.slot_save_path + filename; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; } - std::string filepath = params.slot_save_path + filename; - - server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = ctx_server.queue_tasks.get_new_id(); - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; ctx_server.queue_results.add_waiting_task_id(task.id); ctx_server.queue_tasks.post(task); @@ -3719,32 +3737,28 @@ int main(int argc, char ** argv) { res_error(res, result->to_json()); return; } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + GGML_ASSERT(result.get() != nullptr); + if (type == SERVER_TASK_TYPE_SLOT_RESTORE) { + GGML_ASSERT(result.get()->get_server_task_type() == SERVER_TASK_TYPE_SLOT_SAVE); + } else { + GGML_ASSERT(result.get()->get_server_task_type() == type); + } res_ok(res, result->to_json()); }; - const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { - server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = ctx_server.queue_tasks.get_new_id(); - task.slot_action.slot_id = id_slot; - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(task); - - server_task_result_ptr result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); + const auto handle_slots_save = [&handle_slot_impl](const httplib::Request & req, httplib::Response & res, int id_slot) { + handle_slot_impl(req, res, id_slot, SERVER_TASK_TYPE_SLOT_SAVE); + }; - if (result->is_error()) { - res_error(res, result->to_json()); - return; - } + const auto handle_slots_restore = [&handle_slot_impl](const httplib::Request & req, httplib::Response & res, int id_slot) { + handle_slot_impl(req, res, id_slot, SERVER_TASK_TYPE_SLOT_RESTORE); + }; - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); + const auto handle_slots_erase = [&handle_slot_impl](const httplib::Request & req, httplib::Response & res, int id_slot) { + handle_slot_impl(req, res, id_slot, SERVER_TASK_TYPE_SLOT_ERASE); }; - const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { if (params.slot_save_path.empty()) { res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); return;