@@ -3633,14 +3633,17 @@ int main(int argc, char ** argv) {
36333633 }
36343634
36353635 // request slots data using task queue
3636- server_task task (SERVER_TASK_TYPE_METRICS);
3637- task.id = ctx_server.queue_tasks .get_new_id ();
3638- ctx_server.queue_results .add_waiting_task_id (task.id );
3639- ctx_server.queue_tasks .post (std::move (task), true ); // high-priority task
3636+ int task_id = ctx_server.queue_tasks .get_new_id (); // keep the id in a local: `task` is moved into the queue below
3637+ { // scope `task` so it cannot be referenced after the std::move
3638+ server_task task (SERVER_TASK_TYPE_METRICS);
3639+ task.id = task_id;
3640+ ctx_server.queue_results .add_waiting_task_id (task_id);
3641+ ctx_server.queue_tasks .post (std::move (task), true ); // high-priority task
3642+ }
36403643
36413644 // get the result
3642- server_task_result_ptr result = ctx_server.queue_results .recv (task. id );
3643- ctx_server.queue_results .remove_waiting_task_id (task. id );
3645+ server_task_result_ptr result = ctx_server.queue_results .recv (task_id );
3646+ ctx_server.queue_results .remove_waiting_task_id (task_id );
36443647
36453648 if (result->is_error ()) {
36463649 res_error (res, result->to_json ());
@@ -3669,16 +3672,17 @@ int main(int argc, char ** argv) {
36693672 }
36703673
36713674 // request slots data using task queue
3672- server_task task (SERVER_TASK_TYPE_METRICS);
3673- task.id = ctx_server.queue_tasks .get_new_id ();
3674- task.metrics_reset_bucket = true ;
3675-
3676- ctx_server.queue_results .add_waiting_task_id (task.id );
3677- ctx_server.queue_tasks .post (std::move (task), true ); // high-priority task
3675+ int task_id = ctx_server.queue_tasks .get_new_id (); // keep the id in a local: `task` is moved into the queue below
3676+ { // scope `task` so it cannot be referenced after the std::move
3677+ server_task task (SERVER_TASK_TYPE_METRICS);
3678+ task.id = task_id; task.metrics_reset_bucket = true ; // keep the bucket-reset flag that the removed code set on this task
3679+ ctx_server.queue_results .add_waiting_task_id (task_id);
3680+ ctx_server.queue_tasks .post (std::move (task), true ); // high-priority task
3681+ }
36783682
36793683 // get the result
3680- server_task_result_ptr result = ctx_server.queue_results .recv (task. id );
3681- ctx_server.queue_results .remove_waiting_task_id (task. id );
3684+ server_task_result_ptr result = ctx_server.queue_results .recv (task_id );
3685+ ctx_server.queue_results .remove_waiting_task_id (task_id );
36823686
36833687 if (result->is_error ()) {
36843688 res_error (res, result->to_json ());
@@ -3775,17 +3779,20 @@ int main(int argc, char ** argv) {
37753779 }
37763780 std::string filepath = params.slot_save_path + filename;
37773781
3778- server_task task (SERVER_TASK_TYPE_SLOT_SAVE);
3779- task.id = ctx_server.queue_tasks .get_new_id ();
3780- task.slot_action .slot_id = id_slot;
3781- task.slot_action .filename = filename;
3782- task.slot_action .filepath = filepath;
3782+ int task_id = ctx_server.queue_tasks .get_new_id (); // keep the id in a local: `task` is moved into the queue below
3783+ { // scope `task` so it cannot be referenced after the std::move
3784+ server_task task (SERVER_TASK_TYPE_SLOT_SAVE);
3785+ task.id = task_id; // must be the id that is registered and awaited below; a second get_new_id() call would presumably return a different id
3786+ task.slot_action .slot_id = id_slot;
3787+ task.slot_action .filename = filename;
3788+ task.slot_action .filepath = filepath;
37833789
3784- ctx_server.queue_results .add_waiting_task_id (task.id );
3785- ctx_server.queue_tasks .post (std::move (task));
3790+ ctx_server.queue_results .add_waiting_task_id (task_id);
3791+ ctx_server.queue_tasks .post (std::move (task));
3792+ }
37863793
3787- server_task_result_ptr result = ctx_server.queue_results .recv (task. id );
3788- ctx_server.queue_results .remove_waiting_task_id (task. id );
3794+ server_task_result_ptr result = ctx_server.queue_results .recv (task_id );
3795+ ctx_server.queue_results .remove_waiting_task_id (task_id );
37893796
37903797 if (result->is_error ()) {
37913798 res_error (res, result->to_json ());
@@ -3804,17 +3811,20 @@ int main(int argc, char ** argv) {
38043811 }
38053812 std::string filepath = params.slot_save_path + filename;
38063813
3807- server_task task (SERVER_TASK_TYPE_SLOT_RESTORE);
3808- task.id = ctx_server.queue_tasks .get_new_id ();
3809- task.slot_action .slot_id = id_slot;
3810- task.slot_action .filename = filename;
3811- task.slot_action .filepath = filepath;
3814+ int task_id = ctx_server.queue_tasks .get_new_id (); // keep the id in a local: `task` is moved into the queue below
3815+ { // scope `task` so it cannot be referenced after the std::move
3816+ server_task task (SERVER_TASK_TYPE_SLOT_RESTORE);
3817+ task.id = task_id; // must be the id that is registered and awaited below; a second get_new_id() call would presumably return a different id
3818+ task.slot_action .slot_id = id_slot;
3819+ task.slot_action .filename = filename;
3820+ task.slot_action .filepath = filepath;
38123821
3813- ctx_server.queue_results .add_waiting_task_id (task.id );
3814- ctx_server.queue_tasks .post (std::move (task));
3822+ ctx_server.queue_results .add_waiting_task_id (task_id);
3823+ ctx_server.queue_tasks .post (std::move (task));
3824+ }
38153825
3816- server_task_result_ptr result = ctx_server.queue_results .recv (task. id );
3817- ctx_server.queue_results .remove_waiting_task_id (task. id );
3826+ server_task_result_ptr result = ctx_server.queue_results .recv (task_id );
3827+ ctx_server.queue_results .remove_waiting_task_id (task_id );
38183828
38193829 if (result->is_error ()) {
38203830 res_error (res, result->to_json ());
@@ -3826,15 +3836,18 @@ int main(int argc, char ** argv) {
38263836 };
38273837
38283838 const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */ , httplib::Response & res, int id_slot) {
3829- server_task task (SERVER_TASK_TYPE_SLOT_ERASE);
3830- task.id = ctx_server.queue_tasks .get_new_id ();
3831- task.slot_action .slot_id = id_slot;
3839+ int task_id = ctx_server.queue_tasks .get_new_id (); // keep the id in a local: `task` is moved into the queue below
3840+ { // scope `task` so it cannot be referenced after the std::move
3841+ server_task task (SERVER_TASK_TYPE_SLOT_ERASE);
3842+ task.id = task_id; // must be the id that is registered and awaited below; a second get_new_id() call would presumably return a different id
3843+ task.slot_action .slot_id = id_slot;
38323844
3833- ctx_server.queue_results .add_waiting_task_id (task.id );
3834- ctx_server.queue_tasks .post (std::move (task));
3845+ ctx_server.queue_results .add_waiting_task_id (task_id);
3846+ ctx_server.queue_tasks .post (std::move (task));
3847+ }
38353848
3836- server_task_result_ptr result = ctx_server.queue_results .recv (task. id );
3837- ctx_server.queue_results .remove_waiting_task_id (task. id );
3849+ server_task_result_ptr result = ctx_server.queue_results .recv (task_id );
3850+ ctx_server.queue_results .remove_waiting_task_id (task_id );
38383851
38393852 if (result->is_error ()) {
38403853 res_error (res, result->to_json ());
@@ -3938,45 +3951,48 @@ int main(int argc, char ** argv) {
39383951 }
39393952
39403953 auto completion_id = gen_chatcmplid ();
3941- std::vector<server_task> tasks;
3942-
3943- try {
3944- const auto & prompt = data.at (" prompt" );
3945- // TODO: this log can become very long, put it behind a flag or think about a more compact format
3946- // SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
3947-
3948- std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts (ctx_server.vocab , prompt, true , true );
3949- tasks.reserve (tokenized_prompts.size ());
3950- for (size_t i = 0 ; i < tokenized_prompts.size (); i++) {
3951- server_task task = server_task (type);
3952-
3953- task.id = ctx_server.queue_tasks .get_new_id ();
3954- task.index = i;
3955-
3956- task.prompt_tokens = std::move (tokenized_prompts[i]);
3957- task.params = server_task::params_from_json_cmpl (
3958- ctx_server.ctx ,
3959- ctx_server.params_base ,
3960- data);
3961- task.id_selected_slot = json_value (data, " id_slot" , -1 );
3962-
3963- // OAI-compat
3964- task.params .oaicompat = oaicompat;
3965- task.params .oaicompat_cmpl_id = completion_id;
3966- // oaicompat_model is already populated by params_from_json_cmpl
3954+ std::unordered_set<int > task_ids;
3955+ {
3956+ std::vector<server_task> tasks;
39673957
3968- tasks.push_back (std::move (task));
3958+ try {
3959+ const auto & prompt = data.at (" prompt" );
3960+ // TODO: this log can become very long, put it behind a flag or think about a more compact format
3961+ // SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
3962+
3963+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts (ctx_server.vocab , prompt, true , true );
3964+ tasks.reserve (tokenized_prompts.size ());
3965+ for (size_t i = 0 ; i < tokenized_prompts.size (); i++) {
3966+ server_task task = server_task (type);
3967+
3968+ task.id = ctx_server.queue_tasks .get_new_id ();
3969+ task.index = i;
3970+
3971+ task.prompt_tokens = std::move (tokenized_prompts[i]);
3972+ task.params = server_task::params_from_json_cmpl (
3973+ ctx_server.ctx ,
3974+ ctx_server.params_base ,
3975+ data);
3976+ task.id_selected_slot = json_value (data, " id_slot" , -1 );
3977+
3978+ // OAI-compat
3979+ task.params .oaicompat = oaicompat;
3980+ task.params .oaicompat_cmpl_id = completion_id;
3981+ // oaicompat_model is already populated by params_from_json_cmpl
3982+
3983+ tasks.push_back (std::move (task));
3984+ }
3985+ } catch (const std::exception & e) {
3986+ res_error (res, format_error_response (e.what (), ERROR_TYPE_INVALID_REQUEST));
3987+ return ;
39693988 }
3970- } catch (const std::exception & e) {
3971- res_error (res, format_error_response (e.what (), ERROR_TYPE_INVALID_REQUEST));
3972- return ;
3973- }
39743989
3975- ctx_server.queue_results .add_waiting_tasks (tasks);
3976- ctx_server.queue_tasks .post (std::move (tasks));
3990+ task_ids = server_task::get_list_id (tasks); // collect the ids while `tasks` is still intact; it is moved into the queue two lines below
3991+ ctx_server.queue_results .add_waiting_tasks (tasks);
3992+ ctx_server.queue_tasks .post (std::move (tasks));
3993+ } // `tasks` goes out of scope here; only `task_ids` is used afterwards
39773994
39783995 bool stream = json_value (data, " stream" , false );
3979- const auto task_ids = server_task::get_list_id (tasks);
39803996
39813997 if (!stream) {
39823998 ctx_server.receive_multi_results (task_ids, [&](std::vector<server_task_result_ptr> & results) {
@@ -4268,6 +4284,7 @@ int main(int argc, char ** argv) {
42684284 // create and queue the task
42694285 json responses = json::array ();
42704286 bool error = false ;
4287+ std::unordered_set<int > task_ids;
42714288 {
42724289 std::vector<server_task> tasks;
42734290 for (size_t i = 0 ; i < tokenized_prompts.size (); i++) {
@@ -4283,24 +4300,23 @@ int main(int argc, char ** argv) {
42834300 tasks.push_back (std::move (task));
42844301 }
42854302
4303+ task_ids = server_task::get_list_id (tasks); // collect the ids while `tasks` is still intact; it is moved into the queue below
42864304 ctx_server.queue_results .add_waiting_tasks (tasks);
42874305 ctx_server.queue_tasks .post (std::move (tasks));
4306+ } // `tasks` goes out of scope here; only `task_ids` is used afterwards
42884307
4289- // get the result
4290- std::unordered_set<int > task_ids = server_task::get_list_id (tasks);
4308+ // get the result
4309+ ctx_server.receive_multi_results (task_ids, [&](std::vector<server_task_result_ptr> & results) {
4310+ for (auto & res : results) {
4311+ GGML_ASSERT (dynamic_cast <server_task_result_embd*>(res.get ()) != nullptr );
4312+ responses.push_back (res->to_json ());
4313+ }
4314+ }, [&](const json & error_data) {
4315+ res_error (res, error_data);
4316+ error = true ;
4317+ }, req.is_connection_closed );
42914318
4292- ctx_server.receive_multi_results (task_ids, [&](std::vector<server_task_result_ptr> & results) {
4293- for (auto & res : results) {
4294- GGML_ASSERT (dynamic_cast <server_task_result_embd*>(res.get ()) != nullptr );
4295- responses.push_back (res->to_json ());
4296- }
4297- }, [&](const json & error_data) {
4298- res_error (res, error_data);
4299- error = true ;
4300- }, req.is_connection_closed );
4301-
4302- ctx_server.queue_results .remove_waiting_task_ids (task_ids);
4303- }
4319+ ctx_server.queue_results .remove_waiting_task_ids (task_ids);
43044320
43054321 if (error) {
43064322 return ;
@@ -4367,6 +4383,7 @@ int main(int argc, char ** argv) {
43674383 // create and queue the task
43684384 json responses = json::array ();
43694385 bool error = false ;
4386+ std::unordered_set<int > task_ids;
43704387 {
43714388 std::vector<server_task> tasks;
43724389 std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts (ctx_server.vocab , documents, /* add_special */ false , true );
@@ -4379,23 +4396,21 @@ int main(int argc, char ** argv) {
43794396 tasks.push_back (std::move (task));
43804397 }
43814398
4399+ task_ids = server_task::get_list_id (tasks);
43824400 ctx_server.queue_results .add_waiting_tasks (tasks);
43834401 ctx_server.queue_tasks .post (std::move (tasks));
4384-
4385- // get the result
4386- std::unordered_set<int > task_ids = server_task::get_list_id (tasks);
4387-
4388- ctx_server.receive_multi_results (task_ids, [&](std::vector<server_task_result_ptr> & results) {
4389- for (auto & res : results) {
4390- GGML_ASSERT (dynamic_cast <server_task_result_rerank*>(res.get ()) != nullptr );
4391- responses.push_back (res->to_json ());
4392- }
4393- }, [&](const json & error_data) {
4394- res_error (res, error_data);
4395- error = true ;
4396- }, req.is_connection_closed );
43974402 }
43984403
4404+ ctx_server.receive_multi_results (task_ids, [&](std::vector<server_task_result_ptr> & results) { // gather every rerank result into `responses`
4405+ for (auto & res : results) {
4406+ GGML_ASSERT (dynamic_cast <server_task_result_rerank*>(res.get ()) != nullptr );
4407+ responses.push_back (res->to_json ());
4408+ }
4409+ }, [&](const json & error_data) {
4410+ res_error (res, error_data);
4411+ error = true ;
4412+ }, req.is_connection_closed );
4413+ ctx_server.queue_results .remove_waiting_task_ids (task_ids); // unregister the ids before the early return on error, as the embeddings handler does
43994414 if (error) {
44004415 return ;
44014416 }
@@ -4431,14 +4446,19 @@ int main(int argc, char ** argv) {
44314446 res_error (res, format_error_response (" Request body must be an array" , ERROR_TYPE_INVALID_REQUEST));
44324447 return ;
44334448 }
4434- server_task task (SERVER_TASK_TYPE_SET_LORA);
4435- task.id = ctx_server.queue_tasks .get_new_id ();
4436- task.set_lora = parse_lora_request (ctx_server.params_base .lora_adapters , body);
4437- ctx_server.queue_results .add_waiting_task_id (task.id );
4438- ctx_server.queue_tasks .post (std::move (task));
44394449
4440- server_task_result_ptr result = ctx_server.queue_results .recv (task.id );
4441- ctx_server.queue_results .remove_waiting_task_id (task.id );
4450+ int task_id = ctx_server.queue_tasks .get_new_id (); // keep the id in a local: `task` is moved into the queue below
4451+ { // scope `task` so it cannot be referenced after the std::move
4452+ server_task task (SERVER_TASK_TYPE_SET_LORA);
4453+ task.id = task_id; // must be the id that is registered and awaited below; a second get_new_id() call would presumably return a different id
4454+ task.set_lora = parse_lora_request (ctx_server.params_base .lora_adapters , body);
4455+ ctx_server.queue_results .add_waiting_task_id (task_id);
4456+ ctx_server.queue_tasks .post (std::move (task));
4457+ }
4458+
4459+ // get the result
4460+ server_task_result_ptr result = ctx_server.queue_results .recv (task_id);
4461+ ctx_server.queue_results .remove_waiting_task_id (task_id);
44424462
44434463 if (result->is_error ()) {
44444464 res_error (res, result->to_json ());
0 commit comments