@@ -5108,6 +5108,106 @@ struct server_routes {
51085108 return res;
51095109 };
51105110
5111+ //
5112+ // router server
5113+ //
5114+ char ** envp;
5115+ std::map<std::string, server_spawn_instance> map_model_to_port;
5116+ void maybe_load_it_why_not (std::string & custom_model) {
5117+ // HACKYYYY, but for demo purpose; we load the model if it's in the cached list
5118+ if (map_model_to_port.find (custom_model) != map_model_to_port.end ()) {
5119+ return ; // already loaded, do nothing
5120+ }
5121+ auto models = common_list_cached_models ();
5122+ for (const auto & model : models) {
5123+ auto m = model.to_string ();
5124+ if (m == custom_model) {
5125+ server_router_create_instance (envp, map_model_to_port, m);
5126+ std::this_thread::sleep_for (std::chrono::seconds (5 )); // hacky wait for the process to be ready
5127+ return ; // nice
5128+ }
5129+ }
5130+ }
5131+ std::string get_one_if_has_only_one (std::string & custom_model) {
5132+ // HACKYYYY, but for demo purpose; we get the only model if there's only one
5133+ if (map_model_to_port.size () == 1 ) {
5134+ return map_model_to_port.begin ()->first ;
5135+ }
5136+ return custom_model;
5137+ }
5138+ server_http_context::handler_t proxy_get = [this ](const server_http_req & req) {
5139+ std::string method = " GET" ;
5140+ std::string model = req.get_param (" model" );
5141+ maybe_load_it_why_not (model);
5142+ model = get_one_if_has_only_one (model);
5143+ return handle_proxy (req, method, model);
5144+ };
5145+ server_http_context::handler_t proxy_post = [this ](const server_http_req & req) {
5146+ std::string method = " POST" ;
5147+ json body = json::parse (req.body );
5148+ std::string model = json_value (body, " model" , std::string ());
5149+ maybe_load_it_why_not (model);
5150+ model = get_one_if_has_only_one (model);
5151+ return handle_proxy (req, method, model);
5152+ };
5153+ server_http_res_ptr handle_proxy (const server_http_req & req, std::string & method, std::string model) {
5154+ if (map_model_to_port.find (model) == map_model_to_port.end ()) {
5155+ auto res = std::make_unique<server_res_generator>(ctx_server);
5156+ res->error (format_error_response (" model parameter is invalid" , ERROR_TYPE_INVALID_REQUEST));
5157+ return server_http_res_ptr (std::move (res));
5158+ }
5159+ server_http_res_ptr res (new server_http_client (
5160+ method, params.hostname , map_model_to_port[model].port ,
5161+ req.path , req.headers , req.body , req.should_stop
5162+ ));
5163+ return res;
5164+ }
5165+ server_http_context::handler_t post_router_models_load = [this ](const server_http_req & req) {
5166+ auto res = std::make_unique<server_res_generator>(ctx_server);
5167+ json body = json::parse (req.body );
5168+ std::string model = json_value (body, " model" , std::string ());
5169+ int status = server_router_create_instance (envp, map_model_to_port, model);
5170+ if (status != 0 ) {
5171+ res->error (format_error_response (" fail to start the process" , ERROR_TYPE_SERVER));
5172+ return res;
5173+ }
5174+ res->ok ({{" success" , true }});
5175+ return res;
5176+ };
5177+ server_http_context::handler_t get_router_models = [this ](const server_http_req &) {
5178+ auto res = std::make_unique<server_res_generator>(ctx_server);
5179+ json models_json = json::array ();
5180+ auto models = common_list_cached_models ();
5181+ for (const auto & model : models) {
5182+ auto model_name = model.to_string ();
5183+ bool loaded = map_model_to_port.find (model.to_string ()) != map_model_to_port.end (); // TODO: thread safety
5184+ models_json.push_back (json {
5185+ {" model" , model_name},
5186+ {" name" , model_name},
5187+ {" id" , model_name},
5188+ // TODO: other fields...
5189+ {" status" , {
5190+ {" value" , loaded ? " loaded" : " unloaded" }
5191+ }},
5192+ });
5193+ }
5194+ res->ok ({{" data" , models_json}});
5195+ return res;
5196+ };
5197+ server_http_context::handler_t post_router_models_unload = [this ](const server_http_req & req) {
5198+ auto res = std::make_unique<server_res_generator>(ctx_server);
5199+ json body = json::parse (req.body );
5200+ std::string model = json_value (body, " model" , std::string ());
5201+ model = get_one_if_has_only_one (model);
5202+ if (map_model_to_port.find (model) == map_model_to_port.end ()) {
5203+ res->error (format_error_response (" model parameter is invalid" , ERROR_TYPE_INVALID_REQUEST));
5204+ return res;
5205+ }
5206+ server_router_kill_single (map_model_to_port, model);
5207+ res->ok ({{" success" , true }});
5208+ return res;
5209+ };
5210+
51115211private:
51125212 std::unique_ptr<server_res_generator> handle_completions_impl (
51135213 server_task_type type,
@@ -5501,7 +5601,7 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
55015601 };
55025602}
55035603
5504- int main (int argc, char ** argv) {
5604+ int main (int argc, char ** argv, char ** envp ) {
55055605 // own arguments required by this example
55065606 common_params params;
55075607
@@ -5549,6 +5649,34 @@ int main(int argc, char ** argv) {
55495649 // register API routes
55505650 server_routes routes (params, ctx_server, ctx_http);
55515651
5652+ // hacky, replace handlers with proxy handlers if this is a router server
5653+ bool is_router_server = params.model .path == DEFAULT_MODEL_PATH;
5654+ if (is_router_server) {
5655+ routes.envp = envp;
5656+ routes.get_props = routes.proxy_get ;
5657+ routes.post_props = routes.proxy_post ;
5658+ // routes.get_models = routes.proxy_get;
5659+ routes.post_completions = routes.proxy_post ;
5660+ routes.post_completions_oai = routes.proxy_post ;
5661+ routes.post_chat_completions = routes.proxy_post ;
5662+ routes.post_infill = routes.proxy_post ;
5663+ routes.post_embeddings = routes.proxy_post ;
5664+ routes.post_embeddings_oai = routes.proxy_post ;
5665+ routes.post_rerank = routes.proxy_post ;
5666+ routes.post_tokenize = routes.proxy_post ;
5667+ routes.post_detokenize = routes.proxy_post ;
5668+ routes.post_apply_template = routes.proxy_post ;
5669+ routes.get_lora_adapters = routes.proxy_get ;
5670+ routes.post_lora_adapters = routes.proxy_post ;
5671+ routes.get_slots = routes.proxy_get ;
5672+ routes.post_slots = routes.proxy_post ;
5673+
5674+ // custom routes for router
5675+ routes.get_models = routes.get_router_models ;
5676+ ctx_http.post (" /models/load" , ex_wrapper (routes.post_router_models_load ));
5677+ ctx_http.post (" /models/unload" , ex_wrapper (routes.post_router_models_unload ));
5678+ }
5679+
55525680 ctx_http.get (" /health" , ex_wrapper (routes.get_health )); // public endpoint (no API key check)
55535681 ctx_http.get (" /v1/health" , ex_wrapper (routes.get_health )); // public endpoint (no API key check)
55545682 ctx_http.get (" /metrics" , ex_wrapper (routes.get_metrics ));
@@ -5594,6 +5722,8 @@ int main(int argc, char ** argv) {
55945722 llama_backend_free ();
55955723 };
55965724
5725+ if (!is_router_server) { // HACKY
5726+
55975727 // start the HTTP server before loading the model to be able to serve /health requests
55985728 if (!ctx_http.start ()) {
55995729 clean_up ();
@@ -5631,6 +5761,8 @@ int main(int argc, char ** argv) {
56315761 ctx_server.queue_tasks .terminate ();
56325762 };
56335763
5764+ } // end of !is_router_server
5765+
56345766#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
56355767 struct sigaction sigint_action;
56365768 sigint_action.sa_handler = signal_handler;
@@ -5645,6 +5777,8 @@ int main(int argc, char ** argv) {
56455777 SetConsoleCtrlHandler (reinterpret_cast <PHANDLER_ROUTINE>(console_ctrl_handler), true );
56465778#endif
56475779
5780+ if (!is_router_server) { // HACKY
5781+
56485782 LOG_INF (" %s: server is listening on %s\n " , __func__, ctx_http.listening_address .c_str ());
56495783 LOG_INF (" %s: starting the main loop...\n " , __func__);
56505784 // this call blocks the main thread until queue_tasks.terminate() is called
@@ -5655,6 +5789,19 @@ int main(int argc, char ** argv) {
56555789 ctx_http.thread .join ();
56565790 }
56575791 llama_memory_breakdown_print (ctx_server.ctx );
5792+ } else {
5793+ shutdown_handler = [&](int ) {
5794+ ctx_http.stop ();
5795+ };
5796+ if (!ctx_http.start ()) {
5797+ LOG_ERR (" %s: exiting due to HTTP server error\n " , __func__);
5798+ return 1 ;
5799+ }
5800+ ctx_http.is_ready .store (true );
5801+ ctx_http.thread .join (); // keep the main thread alive
5802+ // kill_all_instances(routes.map_model_to_port); // why this also kill the main instance?
5803+ LOG_INF (" %s: server stopped\n " , __func__);
5804+ } // end of !is_router_server
56585805
56595806 return 0 ;
56605807}
0 commit comments