@@ -98,7 +98,7 @@ struct slot_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
-    std::vector<common_lora_adapter_container> lora;
+    std::vector<common_lora_adapter_info> lora;
 
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
@@ -198,15 +198,14 @@ struct server_task {
     bool metrics_reset_bucket = false;
 
     // used by SERVER_TASK_TYPE_SET_LORA
-    std::vector<common_lora_adapter_container> set_lora;
+    std::vector<common_lora_adapter_info> set_lora;
 
     server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
-            const std::vector<common_lora_adapter_container> & lora_base,
             const json & data) {
         slot_params params;
@@ -265,12 +264,12 @@ struct server_task {
 
         if (data.contains("lora")) {
             if (data.at("lora").is_array()) {
-                params.lora = parse_lora_request(lora_base, data.at("lora"));
+                params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
             } else {
                 throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
             }
         } else {
-            params.lora = params_base.lora_adapters;
         }
 
         // TODO: add more sanity checks for the input parameters
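
For context, the request-side contract enforced here: `lora` must be an array of `{id, scale}` objects, exactly as the error message states. A minimal sketch of building such a request body with nlohmann::json, the same library the server uses (prompt and scale values are placeholders; that omitted adapters are disabled for the request is an assumption, not confirmed by this hunk):

```cpp
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // per-request override: apply adapter 0 at half strength
    // (assumption: adapters left out of the array are disabled for this request)
    json body = {
        {"prompt", "Hello"},
        {"lora",   json::array({ { {"id", 0}, {"scale", 0.5} } })}
    };
    // body.dump() is what would be POSTed to the completion endpoint
    return 0;
}
```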
@@ -1132,7 +1131,7 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
-    std::vector<common_lora_adapter_container> lora;
+    std::vector<common_lora_adapter_info> lora;
 
     // the index relative to completion multi-task request
     size_t index = 0;
@@ -1627,11 +1626,15 @@ struct server_response {
 struct server_context {
     common_params params_base;
 
+    // note: keep these alive - they determine the lifetime of the model, context, etc.
+    common_init_result llama_init;
+    common_init_result llama_init_dft;
+
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
-    std::vector<common_lora_adapter_container> lora;
 
     llama_model * model_dft = nullptr;
+
     llama_context_params cparams_dft;
 
     llama_batch batch = {};
@@ -1655,21 +1658,6 @@ struct server_context {
     float slot_prompt_similarity = 0.0f;
 
     ~server_context() {
-        if (ctx) {
-            llama_free(ctx);
-            ctx = nullptr;
-        }
-
-        if (model) {
-            llama_free_model(model);
-            model = nullptr;
-        }
-
-        if (model_dft) {
-            llama_free_model(model_dft);
-            model_dft = nullptr;
-        }
-
         // Clear any sampling context
         for (server_slot & slot : slots) {
             common_sampler_free(slot.smpl);
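
Why it is safe to delete all of this: the pointers are now owned by the `llama_init`/`llama_init_dft` members, and C++ destroys members in reverse declaration order, so the context is still released before the model it depends on. A toy demonstration of that ordering guarantee (names are illustrative only):

```cpp
#include <cstdio>

// illustrative stand-ins: each member announces its own destruction
struct tracer {
    const char * name;
    ~tracer() { std::printf("freeing %s\n", name); }
};

struct holder {
    tracer model   {"model"};   // declared first  -> destroyed last
    tracer context {"context"}; // declared second -> destroyed first
};

int main() {
    holder h;
    // prints "freeing context" then "freeing model" - the same order the
    // removed ~server_context code enforced by hand
}
```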
@@ -1692,11 +1680,10 @@ struct server_context {
 
         params_base = params;
 
-        common_init_result llama_init = common_init_from_params(params_base);
+        llama_init = common_init_from_params(params_base);
 
-        model = llama_init.model;
-        ctx = llama_init.context;
-        lora = llama_init.lora_adapters;
+        model = llama_init.model.get();
+        ctx = llama_init.context.get();
 
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1719,35 +1706,29 @@ struct server_context {
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;
 
-            common_init_result llama_init_dft = common_init_from_params(params_dft);
+            llama_init_dft = common_init_from_params(params_dft);
 
-            model_dft = llama_init_dft.model;
+            model_dft = llama_init_dft.model.get();
 
             if (model_dft == nullptr) {
                 SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
                 return false;
             }
 
-            if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+            if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
                 SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
 
-                llama_free(llama_init_dft.context);
-                llama_free_model(llama_init_dft.model);
-
                 return false;
             }
 
-            const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
+            const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
 
             cparams_dft = common_context_params_to_llama(params_dft);
             cparams_dft.n_batch = n_ctx_dft;
 
             // force F16 KV cache for the draft model for extra performance
             cparams_dft.type_k = GGML_TYPE_F16;
             cparams_dft.type_v = GGML_TYPE_F16;
-
-            // the context is not needed - we will create one for each slot
-            llama_free(llama_init_dft.context);
         }
 
         return true;
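
One behavioral nuance in this hunk: the old code freed the throwaway draft context immediately ("we will create one for each slot"), while now it simply stays alive inside `llama_init_dft`. The saved `cparams_dft` is still, presumably, what each slot's own draft context is built from, roughly along these lines (a sketch only; assumes the public llama.cpp C API and the surrounding file's members):

```cpp
#include "llama.h"

// sketch: per-slot draft-context creation from the params saved above
static llama_context * make_draft_ctx(llama_model * model_dft, llama_context_params cparams_dft) {
    // each slot gets its own small F16-KV context for the draft model;
    // returns nullptr on failure, in which case the slot cannot speculate
    return llama_new_context_with_model(model_dft, cparams_dft);
}
```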
@@ -1898,7 +1879,7 @@ struct server_context {
         if (!are_lora_equal(task.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
-            slot.lora = std::move(task.params.lora);
+            slot.lora = task.params.lora;
         }
 
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
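
The move is replaced with a plain copy, presumably because `task.params` may still be read after the slot is launched (a move would leave an empty list behind) and the list is only a handful of small entries, so copying is cheap. For reference, the `are_lora_equal` guard above likely boils down to an element-wise comparison along these lines (a hypothetical reconstruction, not the server's actual helper):

```cpp
#include <cstddef>
#include <string>
#include <vector>

// assumed shape of the adapter descriptor after the rename
struct common_lora_adapter_info {
    std::string path;
    float       scale = 0.0f;
    void      * ptr   = nullptr; // handle to the loaded adapter
};

static bool are_lora_equal(
        const std::vector<common_lora_adapter_info> & l1,
        const std::vector<common_lora_adapter_info> & l2) {
    if (l1.size() != l2.size()) {
        return false;
    }
    for (size_t i = 0; i < l1.size(); ++i) {
        // equal means: the same adapter applied at the same strength
        if (l1[i].ptr != l2[i].ptr || l1[i].scale != l2[i].scale) {
            return false;
        }
    }
    return true;
}
```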
@@ -2592,7 +2573,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
-                    lora = std::move(task.set_lora);
+                    params_base.lora_adapters = std::move(task.set_lora);
                     auto res = std::make_unique<server_task_result_apply_lora>();
                     res->id = task.id;
                     queue_results.send(std::move(res));
@@ -3671,7 +3652,6 @@ int main(int argc, char ** argv) {
                 ctx_server.model,
                 ctx_server.ctx,
                 ctx_server.params_base,
-                ctx_server.lora,
                 data);
             task.id_selected_slot = json_value(data, "id_slot", -1);
@@ -4098,8 +4078,9 @@ int main(int argc, char ** argv) {
 
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
-            auto & lora = ctx_server.lora[i];
+        const auto & loras = ctx_server.params_base.lora_adapters;
+        for (size_t i = 0; i < loras.size(); ++i) {
+            auto & lora = loras[i];
             result.push_back({
                 {"id",   i},
                 {"path", lora.path},
@@ -4118,7 +4099,7 @@ int main(int argc, char ** argv) {
         }
         server_task task(SERVER_TASK_TYPE_SET_LORA);
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.set_lora = parse_lora_request(ctx_server.lora, body);
+        task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
         ctx_server.queue_results.add_waiting_task_id(task.id);
         ctx_server.queue_tasks.post(task);
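
Both handlers can be exercised end-to-end with the same cpp-httplib the server links against. A minimal client sketch (host, port, and adapter values are placeholders; the `/lora-adapters` route is inferred from the handler names above):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>
#include "httplib.h"

using json = nlohmann::json;

int main() {
    httplib::Client cli("localhost", 8080);

    // GET: list adapters loaded at startup, e.g. [{"id":0,"path":"...","scale":1.0}]
    if (auto res = cli.Get("/lora-adapters")) {
        std::cout << res->body << "\n";
    }

    // POST: re-scale adapter 0; applied once the SET_LORA task is processed
    const json body = json::array({ { {"id", 0}, {"scale", 0.5} } });
    cli.Post("/lora-adapters", body.dump(), "application/json");
}
```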