@@ -64,7 +64,6 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
-    SERVER_TASK_TYPE_SET_LORA,
 };
 
 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
@@ -91,6 +90,8 @@ struct slot_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
+    std::vector<common_lora_adapter_container> lora;
+
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
@@ -114,6 +115,11 @@ struct slot_params {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }
 
+        json lora = json::array();
+        for (size_t i = 0; i < this->lora.size(); ++i) {
+            lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+        }
+
         return json {
             {"n_predict", n_predict}, // Server configured n_predict
             {"seed",      sampling.seed},
@@ -154,6 +160,7 @@ struct slot_params {
             {"speculative.p_min",   speculative.p_min},
             {"timings_per_token",   timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
+            {"lora",                lora},
         };
     }
};
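
Note: for illustration, with two adapters loaded and only the first one active, the serialized slot params would gain an entry shaped like this (a sketch of what the loop above emits, not captured server output):

"lora": [{"id": 0, "scale": 0.5}, {"id": 1, "scale": 0.0}]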
@@ -189,6 +196,7 @@ struct server_task {
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
+            const std::vector<common_lora_adapter_container> & base_lora,
             const json & data) {
         slot_params params;
@@ -245,6 +253,16 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        if (data.contains("lora")) {
+            if (data.at("lora").is_array()) {
+                params.lora = parse_lora_request(base_lora, data.at("lora"));
+            } else {
+                throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+            }
+        } else {
+            params.lora = base_lora;
+        }
+
         // TODO: add more sanity checks for the input parameters
 
         if (params.sampling.penalty_last_n < -1) {
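
Note: parse_lora_request is defined outside this excerpt. A minimal sketch of what it plausibly does, assuming it keeps the validation the old /lora-adapters handler performed (the body and error message below are assumptions, not the PR's actual helper):

// Sketch: derive a per-request adapter list from the server's base list.
// Assumes adapters not named in the request fall back to scale 0.
static std::vector<common_lora_adapter_container> parse_lora_request(
        const std::vector<common_lora_adapter_container> & lora_base,
        const json & data) {
    std::vector<common_lora_adapter_container> lora(lora_base);
    int max_idx = lora.size();
    for (auto & entry : lora) {
        entry.scale = 0.0f;
    }
    for (const auto & entry : data) {
        int   id    = entry.at("id");
        float scale = entry.at("scale");
        if (0 <= id && id < max_idx) {
            lora[id].scale = scale;
        } else {
            throw std::runtime_error("invalid adapter id");
        }
    }
    return lora;
}

A completion request opting in would then carry e.g. "lora": [{"id": 0, "scale": 0.5}]; omitting the field keeps the server-wide configuration (base_lora).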
@@ -989,12 +1007,6 @@ struct server_task_result_slot_erase : server_task_result {
     }
 };
 
-struct server_task_result_apply_lora : server_task_result {
-    virtual json to_json() override {
-        return json {{ "success", true }};
-    }
-};
-
 struct server_slot {
     int id;
     int id_task = -1;
@@ -1009,6 +1021,8 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
+    std::vector<common_lora_adapter_container> lora;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
@@ -1091,7 +1105,8 @@ struct server_slot {
     }
 
     bool can_batch_with(server_slot & other_slot) {
-        return is_non_causal() == other_slot.is_non_causal();
+        return is_non_causal() == other_slot.is_non_causal()
+            && are_lora_equal(lora, other_slot.lora);
     }
 
     bool has_budget(const common_params & global_params) {
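
Note: are_lora_equal is also defined elsewhere in the PR. A plausible sketch, assuming two slots may only share a batch when every adapter matches in identity and scale (the adapter pointer member of common_lora_adapter_container is an assumption):

static bool are_lora_equal(
        const std::vector<common_lora_adapter_container> & l1,
        const std::vector<common_lora_adapter_container> & l2) {
    if (l1.size() != l2.size()) {
        return false;
    }
    for (size_t i = 0; i < l1.size(); ++i) {
        // equality here means: same underlying adapter, same scale
        if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
            return false;
        }
    }
    return true;
}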
@@ -1503,7 +1518,7 @@ struct server_context {
 
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
-    std::vector<common_lora_adapter_container> loras;
+    std::vector<common_lora_adapter_container> lora;
 
     llama_model * model_dft = nullptr;
     llama_context_params cparams_dft;
@@ -1570,7 +1585,7 @@ struct server_context {
 
         model = llama_init.model;
         ctx = llama_init.context;
-        loras = llama_init.lora_adapters;
+        lora = llama_init.lora_adapters;
 
         if (model == nullptr) {
            SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1776,6 +1791,12 @@ struct server_context {
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
+        if (!are_lora_equal(task.params.lora, slot.lora)) {
+            // if lora is changed, we cannot reuse cached tokens
+            slot.cache_tokens.clear();
+            slot.lora = std::move(task.params.lora);
+        }
+
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -2465,13 +2486,6 @@ struct server_context {
                     res->n_erased = n_erased;
                     queue_results.send(std::move(res));
                 } break;
-            case SERVER_TASK_TYPE_SET_LORA:
-                {
-                    common_lora_adapters_apply(ctx, loras);
-                    auto res = std::make_unique<server_task_result_apply_lora>();
-                    res->id = task.id;
-                    queue_results.send(std::move(res));
-                } break;
         }
     }
 
@@ -2808,8 +2822,12 @@ struct server_context {
 
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
-        // make sure we're in the right embedding mode
-        llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
+        if (slot_batched) {
+            // make sure we're in the right embedding mode
+            llama_set_embeddings(ctx, slot_batched->is_non_causal());
+            // apply lora, only need to do it once per batch
+            common_lora_adapters_apply(ctx, slot_batched->lora);
+        }
 
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
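
Note: common_lora_adapters_apply comes from common/; its pre-existing behavior is roughly the following (a paraphrase, not a verbatim copy): clear all adapters from the context, then re-attach every adapter whose scale is non-zero. Since can_batch_with guarantees all slots in a batch share one configuration, doing this once per batch is sufficient.

void common_lora_adapters_apply(struct llama_context * ctx,
        std::vector<common_lora_adapter_container> & lora) {
    llama_lora_adapter_clear(ctx);   // detach everything first
    for (auto & la : lora) {
        if (la.scale != 0.0f) {      // scale 0 means "disabled"
            llama_lora_adapter_set(ctx, la.adapter, la.scale);
        }
    }
}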
@@ -3530,7 +3548,12 @@ int main(int argc, char ** argv) {
             task.index = i;
 
             task.prompt_tokens = std::move(tokenized_prompts[i]);
-            task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+            task.params = server_task::params_from_json_cmpl(
+                    ctx_server.model,
+                    ctx_server.ctx,
+                    ctx_server.params_base,
+                    ctx_server.lora,
+                    data);
             task.id_selected_slot = json_value(data, "id_slot", -1);
 
             // OAI-compat
@@ -3944,8 +3967,8 @@ int main(int argc, char ** argv) {
 
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
-            auto & lora = ctx_server.loras[i];
+        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
+            auto & lora = ctx_server.lora[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
@@ -3957,40 +3980,13 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        const std::vector<json> body = json::parse(req.body);
-        int max_idx = ctx_server.loras.size();
-
-        // clear existing value
-        for (auto & lora : ctx_server.loras) {
-            lora.scale = 0.0f;
-        }
-
-        // set value
-        for (auto entry : body) {
-            int id = entry.at("id");
-            float scale = entry.at("scale");
-            if (0 <= id && id < max_idx) {
-                ctx_server.loras[id].scale = scale;
-            } else {
-                throw std::runtime_error("invalid adapter id");
-            }
-        }
-
-        server_task task(SERVER_TASK_TYPE_SET_LORA);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
-
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-        if (result->is_error()) {
-            res_error(res, result->to_json());
+        const json body = json::parse(req.body);
+        if (!body.is_array()) {
+            res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
-
-        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
-        res_ok(res, result->to_json());
+        ctx_server.lora = parse_lora_request(ctx_server.lora, body);
+        res_ok(res, json{{"success", true}});
     };
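
Note: the endpoint keeps its old request shape, so existing clients keep working; an illustrative payload:

POST /lora-adapters
[ {"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.0} ]

The handler now only rewrites ctx_server.lora; the adapters are actually attached by common_lora_adapters_apply on the next decoded batch, so the blocking SET_LORA round-trip through the task queue is gone.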
 
     //