@@ -5118,36 +5118,39 @@ struct server_routes {
        // Fast path: an instance for this model is already registered.
        if (map_model_to_port.find(custom_model) != map_model_to_port.end()) {
            return; // already loaded, do nothing
        }
        // TODO: maybe unload least recently used model if too many models are loaded?
        // Block until the spawned instance reports status "loaded" back to the router
        // (via POST /models/status), or until its registry entry disappears, which is
        // treated as a failed/removed load.
        // NOTE(review): polls shared map_model_to_port without synchronization and has
        // no timeout -- confirm the handlers run single-threaded or add a lock/deadline.
        auto wait_until_loaded = [this, custom_model]() {
            while (true) {
                bool load_failed = map_model_to_port.find(custom_model) == map_model_to_port.end(); // model is deleted
                bool is_loaded = !load_failed && map_model_to_port[custom_model].status == "loaded";
                if (is_loaded || load_failed) {
                    return;
                }
                std::this_thread::sleep_for(std::chrono::milliseconds(500));
            }
        };
        // The model must be among the locally cached models; if so, spawn a child
        // server instance for it on demand and wait for it to come up.
        auto models = common_list_cached_models();
        for (const auto & model : models) {
            auto m = model.to_string();
            if (m == custom_model) {
                server_router_create_instance(envp, map_model_to_port, m, params.port);
                wait_until_loaded();
                SRV_INF("model %s loaded on-demand\n", custom_model.c_str());
                return;
            }
        }
    }
5131- std::string get_one_if_has_only_one (std::string & custom_model) {
5132- // HACKYYYY, but for demo purpose; we get the only model if there's only one
5133- if (map_model_to_port.size () == 1 ) {
5134- return map_model_to_port.begin ()->first ;
5135- }
5136- return custom_model;
5137- }
51385143 server_http_context::handler_t proxy_get = [this ](const server_http_req & req) {
51395144 std::string method = " GET" ;
51405145 std::string model = req.get_param (" model" );
51415146 maybe_load_it_why_not (model);
5142- model = get_one_if_has_only_one (model);
51435147 return handle_proxy (req, method, model);
51445148 };
51455149 server_http_context::handler_t proxy_post = [this ](const server_http_req & req) {
51465150 std::string method = " POST" ;
51475151 json body = json::parse (req.body );
51485152 std::string model = json_value (body, " model" , std::string ());
51495153 maybe_load_it_why_not (model);
5150- model = get_one_if_has_only_one (model);
51515154 return handle_proxy (req, method, model);
51525155 };
51535156 server_http_res_ptr handle_proxy (const server_http_req & req, std::string & method, std::string model) {
@@ -5166,28 +5169,41 @@ struct server_routes {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        json body = json::parse(req.body);
        std::string model = json_value(body, "model", std::string());
        // Spawn a dedicated server process for the requested model; the router's
        // own port is passed so the child can report its load status back.
        int status = server_router_create_instance(envp, map_model_to_port, model, params.port);
        if (status != 0) {
            res->error(format_error_response("fail to start the process", ERROR_TYPE_SERVER));
            return res;
        }
        res->ok({{"success", true}});
        return res;
    };
5180+ server_http_context::handler_t post_router_models_status = [this ](const server_http_req & req) {
5181+ auto res = std::make_unique<server_res_generator>(ctx_server);
5182+ json body = json::parse (req.body );
5183+ std::string model = json_value (body, " model" , std::string ());
5184+ std::string value = json_value (body, " value" , std::string ());
5185+ if (map_model_to_port.find (model) == map_model_to_port.end ()) {
5186+ res->error (format_error_response (" model parameter is invalid" , ERROR_TYPE_INVALID_REQUEST));
5187+ return res;
5188+ }
5189+ map_model_to_port[model].status = value;
5190+ res->ok ({{" success" , true }});
5191+ return res;
5192+ };
51775193 server_http_context::handler_t get_router_models = [this ](const server_http_req &) {
51785194 auto res = std::make_unique<server_res_generator>(ctx_server);
51795195 json models_json = json::array ();
51805196 auto models = common_list_cached_models ();
51815197 for (const auto & model : models) {
51825198 auto model_name = model.to_string ();
5183- bool loaded = map_model_to_port.find (model.to_string ()) != map_model_to_port.end (); // TODO: thread safety
5199+ bool found = map_model_to_port.find (model.to_string ()) != map_model_to_port.end (); // TODO: thread safety
51845200 models_json.push_back (json {
51855201 {" model" , model_name},
51865202 {" name" , model_name},
51875203 {" id" , model_name},
51885204 // TODO: other fields...
51895205 {" status" , {
5190- {" value" , loaded ? " loaded " : " unloaded" }
5206+ {" value" , found ? map_model_to_port[model_name]. status : " unloaded" }
51915207 }},
51925208 });
51935209 }
@@ -5198,7 +5214,6 @@ struct server_routes {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        json body = json::parse(req.body);
        std::string model = json_value(body, "model", std::string());
        // Reject requests for models that have no running instance registered.
        if (map_model_to_port.find(model) == map_model_to_port.end()) {
            res->error(format_error_response("model parameter is invalid", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
@@ -5673,8 +5688,9 @@ int main(int argc, char ** argv, char ** envp) {

        // custom routes for router: model management endpoints, registered only
        // when this process acts as the router.
        routes.get_models = routes.get_router_models;
        ctx_http.post("/models/load",   ex_wrapper(routes.post_router_models_load));
        ctx_http.post("/models/unload", ex_wrapper(routes.post_router_models_unload));
        ctx_http.post("/models/status", ex_wrapper(routes.post_router_models_status));
    }

    ctx_http.get("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
@@ -5779,6 +5795,21 @@ if (!is_router_server) { // HACKY
57795795
if (!is_router_server) { // HACKY

    // notify to main router if needed
    // When spawned as a child of the router, LLAMA_SERVER_ROUTER_PORT carries the
    // router's port; report status "loaded" so the router's /models/status endpoint
    // can update its registry. Best-effort: the response body is read and ignored.
    char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
    if (router_port != nullptr) {
        SRV_INF("%s: notifying to main router on port %s\n", __func__, router_port);
        // NOTE(review): std::atoi returns 0 on malformed input with no error
        // report -- confirm the router always sets a valid numeric port.
        server_http_client notify_router(
            "POST", params.hostname, std::atoi(router_port),
            "/models/status",
            { {"Content-Type", "application/json"} },
            json {{ "model", params.model_alias }, { "value", "loaded" }}.dump(),
            []() { return false; }
        );
        std::string dummy;
        notify_router.next(dummy); // ignore the response
    }

    LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
    LOG_INF("%s: starting the main loop...\n", __func__);
    // this call blocks the main thread until queue_tasks.terminate() is called
0 commit comments