@@ -2407,7 +2407,7 @@ struct server_context {
 
             params_dft.devices       = params_base.speculative.devices;
             params_dft.model         = params_base.speculative.model;
-            params_dft.n_ctx         = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+            params_dft.n_ctx         = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers  = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel    = 1;
             params_dft.cache_type_k  = params_base.speculative.cache_type_k;
@@ -2495,10 +2495,16 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
-
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
+        const int n_ctx_train = llama_model_n_ctx_train(model);
+
+        int n_ctx_slot = llama_n_ctx_seq(ctx);
+        if (n_ctx_slot > n_ctx_train) {
+            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
+            n_ctx_slot = n_ctx_train;
+        }
+
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
@@ -2527,7 +2533,7 @@ struct server_context {
                 }
             }
 
-            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+            SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -2699,6 +2705,39 @@ struct server_context {
         return ret;
     }
 
+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge (LRU or longest prompt?)
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+
+                // purge slots one by one
+                break;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
 
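The TODO in `try_purge_idle_slots()` above mentions a smarter eviction policy (LRU or longest prompt). Purely as an illustration of the LRU idea, and not part of this PR, a least-recently-used selection could look roughly like the sketch below. It assumes the slot keeps some last-used timestamp, named `t_last_used` here only for illustration; the real member names may differ.

```cpp
// Hypothetical sketch: evict the least-recently-used idle slot instead of
// the first idle slot found. `t_last_used` is an assumed per-slot timestamp.
server_slot * lru = nullptr;

for (auto & slot : slots) {
    if (slot.is_processing() || slot.prompt.n_tokens() == 0) {
        continue; // only consider idle slots that still hold a cached prompt
    }

    if (lru == nullptr || slot.t_last_used < lru->t_last_used) {
        lru = &slot;
    }
}

if (lru != nullptr) {
    SRV_WRN("purging slot %d with %zu tokens\n", lru->id, lru->prompt.tokens.size());

    // drop the slot's sequence from the KV cache and clear its cached prompt
    llama_memory_seq_rm(llama_get_memory(ctx), lru->id, -1, -1);
    lru->prompt.tokens.clear();

    res = true;
}
```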
@@ -3635,9 +3674,10 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // next, batch any pending prompts without exceeding n_batch
-        float alora_scale = -1.0f;
+        float  alora_scale       = -1.0f;
         size_t alora_disabled_id = 0;
+
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -3914,8 +3954,11 @@ struct server_context {
 
                     // truncate any tokens that are beyond n_past for this slot
                     const llama_pos p0 = slot.prompt.tokens.pos_next();
+
+                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
+
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
-                        SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
+                        SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
                         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                         // there is no common part left
@@ -3924,8 +3967,6 @@ struct server_context {
                         slot.prompt.tokens.clear();
                     }
 
-                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
-
                     // check if we should process the image
                     if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                         // process the image
@@ -4126,6 +4167,8 @@ struct server_context {
                     std::string err;
 
                     if (n_batch == 1 && ret == 1) {
+                        // TODO: try to terminate only the largest active slot/sequence and continue with the rest
+                        //       need to remove the tokens from the current batch too
                         err = "Context size has been exceeded.";
                     }
 
@@ -4141,17 +4184,23 @@ struct server_context {
                     // TODO: handle ret == 2 (abort) when we start aborting
 
                     if (!err.empty()) {
-                        SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                        SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                         for (auto & slot : slots) {
-                            send_error(slot, err);
-                            slot.release();
+                            if (slot.is_processing()) {
+                                send_error(slot, err);
+                                slot.release();
+                            }
                         }
+
                         break;
                     }
                 }
 
                 // retry with half the batch size to try to find a free slot in the KV cache
-                n_batch /= 2;
+                if (!try_purge_idle_slots()) {
+                    n_batch /= 2;
+                }
 
                 SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
@@ -4391,6 +4440,15 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // TODO: should we have a separate n_parallel parameter for the server?
+    //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+    if (params.n_parallel == 1 && params.kv_unified == false) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
+
     common_init();
 
     // struct that contains llama context and inference
@@ -4944,7 +5002,7 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
                 auto n_prompt_tokens = inputs[i].size();