Skip to content

Commit 7a7de2a

Browse files
committed
fully working model management
1 parent 016f8b4 commit 7a7de2a

File tree

4 files changed

+439
-1
lines changed

4 files changed

+439
-1
lines changed

tools/server/server-http.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,3 +383,88 @@ void server_http_context::post(const std::string & path, server_http_context::ha
383383
});
384384
}
385385

386+
387+
//
388+
// server_http_client
389+
//
390+
391+
// Proxies one HTTP request to a downstream server (host:port) and exposes the
// response incrementally through the server_http_res interface.
//
// A detached writer thread performs the request and pushes messages into a
// shared pipe; the constructor blocks until the first message (status +
// headers) arrives, and the consumer then streams body chunks via next().
//
// @param method       HTTP verb ("GET", "POST", ...)
// @param host, port   destination server
// @param path         request path forwarded verbatim
// @param headers      request headers forwarded verbatim
// @param body         request body forwarded verbatim
// @param should_stop  cancellation predicate polled while waiting on the pipe
server_http_client::server_http_client(
        const std::string & method,
        const std::string & host,
        int port,
        const std::string & path,
        const std::map<std::string, std::string> & headers,
        const std::string & body,
        const std::function<bool()> should_stop) {
    // shared between reader and writer threads
    auto cli  = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();

    // setup Client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    this->status = 500; // fallback, kept when no response header is ever received
    this->cleanup = [pipe]() {
        // shut both ends so the (detached) writer thread unblocks and exits
        pipe->close_read();
        pipe->close_write();
    };

    // wire up the receive end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next;
    };

    // wire up the HTTP client
    // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            msg.headers[key] = value;
        }
        pipe->write(std::move(msg)); // send headers first
        return true;
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        return pipe->write({{}, 0, std::string(data, data_length)}); // send data chunks
    };

    // prepare the request to destination server
    httplib::Request req;
    {
        req.method = method;
        req.path   = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body             = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }

    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    // `mutable` so the by-value capture `req` can actually be moved into send();
    // without it the capture is const and std::move degenerates into a copy
    this->thread = std::thread([cli, pipe, req]() mutable {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str}); // body
        }
        pipe->close_write(); // signal EOF to reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    this->thread.detach();

    // wait for the first chunk (headers); only overwrite the 500 fallback when a
    // header message was actually received — read() can return false when
    // should_stop() fires, in which case `header` is default-initialized (status 0)
    msg_t header;
    if (pipe->read(header, should_stop)) {
        SRV_DBG("%s", "received response headers\n");
        this->status  = header.status;
        this->headers = header.headers;
    }
}

tools/server/server-http.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,79 @@ struct server_http_context {
7575
// for debugging
7676
std::string listening_address;
7777
};
78+
79+
80+
81+
#include <queue>
82+
#include <mutex>
83+
#include <mutex>
84+
#include <condition_variable>
85+
86+
// HTTP client used by the router server to proxy a request to a spawned model
// instance. Implements the server_http_res interface: status/headers are filled
// by the constructor, the body is streamed through the inherited next().
struct server_http_client : server_http_res {
    // invoked by the destructor; closes both ends of the internal pipe so the
    // detached writer thread unblocks and terminates
    std::function<void()> cleanup = nullptr;
public:
    server_http_client(const std::string & method,
                       const std::string & host,
                       int port,
                       const std::string & path,
                       const std::map<std::string, std::string> & headers,
                       const std::string & body,
                       const std::function<bool()> should_stop);
    ~server_http_client() {
        if (cleanup) {
            cleanup();
        }
    }
private:
    // writer thread; detached by the constructor, so it is never joined here
    std::thread thread;
    // one message on the pipe: either the response header (status + headers)
    // or a chunk of body data
    struct msg_t {
        std::map<std::string, std::string> headers;
        int status = 0;
        std::string data;
    };
    // simple single-producer / single-consumer pipe
    template<typename T>
    struct pipe_t {
        std::mutex mutex;
        std::condition_variable cv;
        std::queue<T> queue;
        std::atomic<bool> writer_closed{false};
        std::atomic<bool> reader_closed{false};
        void close_write() {
            {
                // set the flag under the lock: a reader that already checked
                // writer_closed but has not yet entered wait_for() would
                // otherwise miss this notification (lost wakeup) and stall
                // for a full poll interval
                std::lock_guard<std::mutex> lk(mutex);
                writer_closed.store(true);
            }
            cv.notify_all();
        }
        void close_read() {
            {
                // same lost-wakeup reasoning as close_write()
                std::lock_guard<std::mutex> lk(mutex);
                reader_closed.store(true);
            }
            cv.notify_all();
        }
        // blocking read; returns true with a value in `output`, false on clean
        // EOF (writer closed) or when should_stop() reports cancellation
        bool read(T & output, const std::function<bool()> & should_stop) {
            std::unique_lock<std::mutex> lk(mutex);
            // should_stop() cannot signal the cv, so poll it periodically
            constexpr auto poll_interval = std::chrono::milliseconds(500);
            while (true) {
                if (!queue.empty()) {
                    output = std::move(queue.front());
                    queue.pop();
                    return true;
                }
                if (writer_closed.load()) {
                    return false; // clean EOF
                }
                if (should_stop()) {
                    // signal broken pipe to writer; do NOT call close_read()
                    // here — the mutex is already held and is not recursive
                    reader_closed.store(true);
                    cv.notify_all();
                    return false; // cancelled / reader no longer alive
                }
                cv.wait_for(lk, poll_interval);
            }
        }
        // returns false (broken pipe) when the reader is gone
        bool write(T && data) {
            std::lock_guard<std::mutex> lk(mutex);
            if (reader_closed.load()) {
                return false; // broken pipe
            }
            queue.push(std::move(data));
            cv.notify_one();
            return true;
        }
    };
};

tools/server/server.cpp

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5108,6 +5108,106 @@ struct server_routes {
51085108
return res;
51095109
};
51105110

5111+
//
// router server
//
// environment of the router process, forwarded when spawning child server
// instances (see server_router_create_instance usage below)
char ** envp;
// maps a model name to its spawned instance; handle_proxy reads `.port` from
// the mapped value. NOTE(review): accessed from HTTP handler threads without
// locking — confirm thread safety (the TODO in get_router_models agrees)
std::map<std::string, server_spawn_instance> map_model_to_port;
5116+
void maybe_load_it_why_not(std::string & custom_model) {
5117+
// HACKYYYY, but for demo purpose; we load the model if it's in the cached list
5118+
if (map_model_to_port.find(custom_model) != map_model_to_port.end()) {
5119+
return; // already loaded, do nothing
5120+
}
5121+
auto models = common_list_cached_models();
5122+
for (const auto & model : models) {
5123+
auto m = model.to_string();
5124+
if (m == custom_model) {
5125+
server_router_create_instance(envp, map_model_to_port, m);
5126+
std::this_thread::sleep_for(std::chrono::seconds(5)); // hacky wait for the process to be ready
5127+
return; // nice
5128+
}
5129+
}
5130+
}
5131+
std::string get_one_if_has_only_one(std::string & custom_model) {
5132+
// HACKYYYY, but for demo purpose; we get the only model if there's only one
5133+
if (map_model_to_port.size() == 1) {
5134+
return map_model_to_port.begin()->first;
5135+
}
5136+
return custom_model;
5137+
}
5138+
// Proxy handler for GET endpoints: the target model comes from the `model`
// query parameter; lazily loads it, then forwards the request.
server_http_context::handler_t proxy_get = [this](const server_http_req & req) {
    std::string method = "GET";
    std::string model  = req.get_param("model");
    maybe_load_it_why_not(model);
    // single running instance wins over (possibly empty) param
    model = get_one_if_has_only_one(model);
    return handle_proxy(req, method, model);
};
5145+
// Proxy handler for POST endpoints: the target model comes from the JSON
// body's "model" field; lazily loads it, then forwards the request.
server_http_context::handler_t proxy_post = [this](const server_http_req & req) {
    std::string method = "POST";
    std::string model;
    // tolerate empty / non-JSON bodies instead of throwing out of the handler:
    // fall back to "no model hint" and let the single-instance shortcut or the
    // invalid-model error in handle_proxy deal with it
    json body = json::parse(req.body, nullptr, /* allow_exceptions */ false);
    if (!body.is_discarded()) {
        model = json_value(body, "model", std::string());
    }
    maybe_load_it_why_not(model);
    model = get_one_if_has_only_one(model);
    return handle_proxy(req, method, model);
};
5153+
// Forwards the request to the instance serving `model`; returns a 400-style
// error response when no such instance is running.
server_http_res_ptr handle_proxy(const server_http_req & req, std::string & method, std::string model) {
    // single lookup instead of find() + operator[]
    auto it = map_model_to_port.find(model);
    if (it == map_model_to_port.end()) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
        return server_http_res_ptr(std::move(res));
    }
    // server_http_client blocks until the downstream response headers arrive
    return server_http_res_ptr(std::make_unique<server_http_client>(
        method, params.hostname, it->second.port,
        req.path, req.headers, req.body, req.should_stop
    ));
}
5165+
// POST /models/load — spawns an instance for the model named in the JSON body.
server_http_context::handler_t post_router_models_load = [this](const server_http_req & req) {
    auto res = std::make_unique<server_res_generator>(ctx_server);
    // non-throwing parse: reject bad/missing input as a client error rather
    // than letting the exception escape the handler
    json body = json::parse(req.body, nullptr, /* allow_exceptions */ false);
    std::string model = body.is_discarded() ? std::string() : json_value(body, "model", std::string());
    if (model.empty()) {
        res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
        return res;
    }
    int status = server_router_create_instance(envp, map_model_to_port, model);
    if (status != 0) {
        res->error(format_error_response("fail to start the process", ERROR_TYPE_SERVER));
        return res;
    }
    res->ok({{"success", true}});
    return res;
};
5177+
// GET /models (router mode) — lists cached models and whether each one
// currently has a running instance.
server_http_context::handler_t get_router_models = [this](const server_http_req &) {
    auto res = std::make_unique<server_res_generator>(ctx_server);
    json models_json = json::array();
    auto models = common_list_cached_models();
    for (const auto & model : models) {
        auto model_name = model.to_string();
        // reuse model_name instead of calling to_string() a second time
        bool loaded = map_model_to_port.find(model_name) != map_model_to_port.end(); // TODO: thread safety
        models_json.push_back(json {
            {"model", model_name},
            {"name", model_name},
            {"id", model_name},
            // TODO: other fields...
            {"status", {
                {"value", loaded ? "loaded" : "unloaded"}
            }},
        });
    }
    res->ok({{"data", models_json}});
    return res;
};
5197+
// POST /models/unload — kills the instance serving the model named in the
// JSON body (or the only running instance when there is exactly one).
server_http_context::handler_t post_router_models_unload = [this](const server_http_req & req) {
    auto res = std::make_unique<server_res_generator>(ctx_server);
    // non-throwing parse, consistent with the other router handlers
    json body = json::parse(req.body, nullptr, /* allow_exceptions */ false);
    std::string model = body.is_discarded() ? std::string() : json_value(body, "model", std::string());
    model = get_one_if_has_only_one(model);
    // containment check only; kill_single does its own lookup by name
    if (map_model_to_port.count(model) == 0) {
        res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
        return res;
    }
    server_router_kill_single(map_model_to_port, model);
    res->ok({{"success", true}});
    return res;
};
5210+
51115211
private:
51125212
std::unique_ptr<server_res_generator> handle_completions_impl(
51135213
server_task_type type,
@@ -5501,7 +5601,7 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
55015601
};
55025602
}
55035603

5504-
int main(int argc, char ** argv) {
5604+
int main(int argc, char ** argv, char ** envp) {
55055605
// own arguments required by this example
55065606
common_params params;
55075607

@@ -5549,6 +5649,34 @@ int main(int argc, char ** argv) {
55495649
// register API routes
55505650
server_routes routes(params, ctx_server, ctx_http);
55515651

5652+
// hacky, replace handlers with proxy handlers if this is a router server
5653+
bool is_router_server = params.model.path == DEFAULT_MODEL_PATH;
5654+
if (is_router_server) {
5655+
routes.envp = envp;
5656+
routes.get_props = routes.proxy_get;
5657+
routes.post_props = routes.proxy_post;
5658+
// routes.get_models = routes.proxy_get;
5659+
routes.post_completions = routes.proxy_post;
5660+
routes.post_completions_oai = routes.proxy_post;
5661+
routes.post_chat_completions = routes.proxy_post;
5662+
routes.post_infill = routes.proxy_post;
5663+
routes.post_embeddings = routes.proxy_post;
5664+
routes.post_embeddings_oai = routes.proxy_post;
5665+
routes.post_rerank = routes.proxy_post;
5666+
routes.post_tokenize = routes.proxy_post;
5667+
routes.post_detokenize = routes.proxy_post;
5668+
routes.post_apply_template = routes.proxy_post;
5669+
routes.get_lora_adapters = routes.proxy_get;
5670+
routes.post_lora_adapters = routes.proxy_post;
5671+
routes.get_slots = routes.proxy_get;
5672+
routes.post_slots = routes.proxy_post;
5673+
5674+
// custom routes for router
5675+
routes.get_models = routes.get_router_models;
5676+
ctx_http.post("/models/load", ex_wrapper(routes.post_router_models_load));
5677+
ctx_http.post("/models/unload", ex_wrapper(routes.post_router_models_unload));
5678+
}
5679+
55525680
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
55535681
ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
55545682
ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
@@ -5594,6 +5722,8 @@ int main(int argc, char ** argv) {
55945722
llama_backend_free();
55955723
};
55965724

5725+
if (!is_router_server) { // HACKY
5726+
55975727
// start the HTTP server before loading the model to be able to serve /health requests
55985728
if (!ctx_http.start()) {
55995729
clean_up();
@@ -5631,6 +5761,8 @@ int main(int argc, char ** argv) {
56315761
ctx_server.queue_tasks.terminate();
56325762
};
56335763

5764+
} // end of !is_router_server
5765+
56345766
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
56355767
struct sigaction sigint_action;
56365768
sigint_action.sa_handler = signal_handler;
@@ -5645,6 +5777,8 @@ int main(int argc, char ** argv) {
56455777
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
56465778
#endif
56475779

5780+
if (!is_router_server) { // HACKY
5781+
56485782
LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
56495783
LOG_INF("%s: starting the main loop...\n", __func__);
56505784
// this call blocks the main thread until queue_tasks.terminate() is called
@@ -5655,6 +5789,19 @@ int main(int argc, char ** argv) {
56555789
ctx_http.thread.join();
56565790
}
56575791
llama_memory_breakdown_print(ctx_server.ctx);
5792+
} else {
5793+
shutdown_handler = [&](int) {
5794+
ctx_http.stop();
5795+
};
5796+
if (!ctx_http.start()) {
5797+
LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
5798+
return 1;
5799+
}
5800+
ctx_http.is_ready.store(true);
5801+
ctx_http.thread.join(); // keep the main thread alive
5802+
// kill_all_instances(routes.map_model_to_port); // why this also kill the main instance?
5803+
LOG_INF("%s: server stopped\n", __func__);
5804+
} // end of !is_router_server
56585805

56595806
return 0;
56605807
}

0 commit comments

Comments
 (0)