Skip to content

Commit 2a20068

Browse files
committed
improve maybe_load_it_why_not
1 parent 7a7de2a commit 2a20068

File tree

2 files changed

+68
-19
lines changed

2 files changed

+68
-19
lines changed

tools/server/server.cpp

Lines changed: 48 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -5118,36 +5118,39 @@ struct server_routes {
51185118
if (map_model_to_port.find(custom_model) != map_model_to_port.end()) {
51195119
return; // already loaded, do nothing
51205120
}
5121+
// TODO: maybe unload least recently used model if too many models are loaded?
5122+
auto wait_until_loaded = [this, custom_model]() {
5123+
while (true) {
5124+
bool load_failed = map_model_to_port.find(custom_model) == map_model_to_port.end(); // model is deleted
5125+
bool is_loaded = !load_failed && map_model_to_port[custom_model].status == "loaded";
5126+
if (is_loaded || load_failed) {
5127+
return;
5128+
}
5129+
std::this_thread::sleep_for(std::chrono::milliseconds(500));
5130+
}
5131+
};
51215132
auto models = common_list_cached_models();
51225133
for (const auto & model : models) {
51235134
auto m = model.to_string();
51245135
if (m == custom_model) {
5125-
server_router_create_instance(envp, map_model_to_port, m);
5126-
std::this_thread::sleep_for(std::chrono::seconds(5)); // hacky wait for the process to be ready
5127-
return; // nice
5136+
server_router_create_instance(envp, map_model_to_port, m, params.port);
5137+
wait_until_loaded();
5138+
SRV_INF("model %s loaded on-demand\n", custom_model.c_str());
5139+
return;
51285140
}
51295141
}
51305142
}
5131-
std::string get_one_if_has_only_one(std::string & custom_model) {
5132-
// HACKYYYY, but for demo purpose; we get the only model if there's only one
5133-
if (map_model_to_port.size() == 1) {
5134-
return map_model_to_port.begin()->first;
5135-
}
5136-
return custom_model;
5137-
}
51385143
server_http_context::handler_t proxy_get = [this](const server_http_req & req) {
51395144
std::string method = "GET";
51405145
std::string model = req.get_param("model");
51415146
maybe_load_it_why_not(model);
5142-
model = get_one_if_has_only_one(model);
51435147
return handle_proxy(req, method, model);
51445148
};
51455149
server_http_context::handler_t proxy_post = [this](const server_http_req & req) {
51465150
std::string method = "POST";
51475151
json body = json::parse(req.body);
51485152
std::string model = json_value(body, "model", std::string());
51495153
maybe_load_it_why_not(model);
5150-
model = get_one_if_has_only_one(model);
51515154
return handle_proxy(req, method, model);
51525155
};
51535156
server_http_res_ptr handle_proxy(const server_http_req & req, std::string & method, std::string model) {
@@ -5166,28 +5169,41 @@ struct server_routes {
51665169
auto res = std::make_unique<server_res_generator>(ctx_server);
51675170
json body = json::parse(req.body);
51685171
std::string model = json_value(body, "model", std::string());
5169-
int status = server_router_create_instance(envp, map_model_to_port, model);
5172+
int status = server_router_create_instance(envp, map_model_to_port, model, params.port);
51705173
if (status != 0) {
51715174
res->error(format_error_response("fail to start the process", ERROR_TYPE_SERVER));
51725175
return res;
51735176
}
51745177
res->ok({{"success", true}});
51755178
return res;
51765179
};
5180+
server_http_context::handler_t post_router_models_status = [this](const server_http_req & req) {
5181+
auto res = std::make_unique<server_res_generator>(ctx_server);
5182+
json body = json::parse(req.body);
5183+
std::string model = json_value(body, "model", std::string());
5184+
std::string value = json_value(body, "value", std::string());
5185+
if (map_model_to_port.find(model) == map_model_to_port.end()) {
5186+
res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
5187+
return res;
5188+
}
5189+
map_model_to_port[model].status = value;
5190+
res->ok({{"success", true}});
5191+
return res;
5192+
};
51775193
server_http_context::handler_t get_router_models = [this](const server_http_req &) {
51785194
auto res = std::make_unique<server_res_generator>(ctx_server);
51795195
json models_json = json::array();
51805196
auto models = common_list_cached_models();
51815197
for (const auto & model : models) {
51825198
auto model_name = model.to_string();
5183-
bool loaded = map_model_to_port.find(model.to_string()) != map_model_to_port.end(); // TODO: thread safety
5199+
bool found = map_model_to_port.find(model.to_string()) != map_model_to_port.end(); // TODO: thread safety
51845200
models_json.push_back(json {
51855201
{"model", model_name},
51865202
{"name", model_name},
51875203
{"id", model_name},
51885204
// TODO: other fields...
51895205
{"status", {
5190-
{"value", loaded ? "loaded" : "unloaded"}
5206+
{"value", found ? map_model_to_port[model_name].status : "unloaded"}
51915207
}},
51925208
});
51935209
}
@@ -5198,7 +5214,6 @@ struct server_routes {
51985214
auto res = std::make_unique<server_res_generator>(ctx_server);
51995215
json body = json::parse(req.body);
52005216
std::string model = json_value(body, "model", std::string());
5201-
model = get_one_if_has_only_one(model);
52025217
if (map_model_to_port.find(model) == map_model_to_port.end()) {
52035218
res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
52045219
return res;
@@ -5673,8 +5688,9 @@ int main(int argc, char ** argv, char ** envp) {
56735688

56745689
// custom routes for router
56755690
routes.get_models = routes.get_router_models;
5676-
ctx_http.post("/models/load", ex_wrapper(routes.post_router_models_load));
5691+
ctx_http.post("/models/load", ex_wrapper(routes.post_router_models_load));
56775692
ctx_http.post("/models/unload", ex_wrapper(routes.post_router_models_unload));
5693+
ctx_http.post("/models/status", ex_wrapper(routes.post_router_models_status));
56785694
}
56795695

56805696
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
@@ -5779,6 +5795,21 @@ if (!is_router_server) { // HACKY
57795795

57805796
if (!is_router_server) { // HACKY
57815797

5798+
// notify to main router if needed
5799+
char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
5800+
if (router_port != nullptr) {
5801+
SRV_INF("%s: notifying to main router on port %s\n", __func__, router_port);
5802+
server_http_client notify_router(
5803+
"POST", params.hostname, std::atoi(router_port),
5804+
"/models/status",
5805+
{ {"Content-Type", "application/json"} },
5806+
json {{ "model", params.model_alias }, { "value", "loaded" }}.dump(),
5807+
[]() { return false; }
5808+
);
5809+
std::string dummy;
5810+
notify_router.next(dummy); // ignore the response
5811+
}
5812+
57825813
LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
57835814
LOG_INF("%s: starting the main loop...\n", __func__);
57845815
// this call blocks the main thread until queue_tasks.terminate() is called

tools/server/utils.hpp

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1615,9 +1615,10 @@ struct server_spawn_instance {
16151615
pid_t pid = 0;
16161616
int port = 0;
16171617
std::thread th;
1618+
std::string status = "loading"; // "loading", "loaded"
16181619
};
16191620

1620-
inline int server_router_create_instance(char ** envp, std::map<std::string, server_spawn_instance> & mapping, const std::string & hf_model) {
1621+
inline int server_router_create_instance(char ** envp, std::map<std::string, server_spawn_instance> & mapping, const std::string & hf_model, int router_port) {
16211622
server_spawn_instance inst;
16221623
inst.port = rand() % 10000 + 20000; // random port between 20000 and 29999
16231624

@@ -1635,6 +1636,8 @@ inline int server_router_create_instance(char ** envp, std::map<std::string, ser
16351636
arg_strs.push_back(path);
16361637
arg_strs.push_back("-hf");
16371638
arg_strs.push_back(hf_model);
1639+
arg_strs.push_back("--alias");
1640+
arg_strs.push_back(hf_model);
16381641
arg_strs.push_back("--port");
16391642
arg_strs.push_back(std::to_string(inst.port));
16401643

@@ -1645,7 +1648,22 @@ inline int server_router_create_instance(char ** envp, std::map<std::string, ser
16451648
}
16461649
child_argv.push_back(nullptr);
16471650

1648-
if (posix_spawn(&pid, path.c_str(), NULL, NULL, child_argv.data(), envp) != 0) {
1651+
// clone envp while adding LLAMA_SERVER_ROUTER_PORT
1652+
std::vector<std::string> child_envs;
1653+
std::vector<char *> child_envp;
1654+
{
1655+
for (char ** e = envp; *e != nullptr; ++e) {
1656+
child_envs.emplace_back(*e);
1657+
}
1658+
child_envs.emplace_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(router_port));
1659+
child_envp.reserve(child_envs.size() + 1);
1660+
for (auto & s : child_envs) {
1661+
child_envp.push_back(const_cast<char *>(s.c_str()));
1662+
}
1663+
child_envp.push_back(nullptr);
1664+
}
1665+
1666+
if (posix_spawn(&pid, path.c_str(), NULL, NULL, child_argv.data(), child_envp.data()) != 0) {
16491667
perror("posix_spawn");
16501668
exit(1); // for testing only
16511669
} else {

0 commit comments

Comments
 (0)