@@ -2407,7 +2407,7 @@ struct server_context {
 
         params_dft.devices      = params_base.speculative.devices;
         params_dft.model        = params_base.speculative.model;
-        params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+        params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
         params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
         params_dft.n_parallel   = 1;
         params_dft.cache_type_k = params_base.speculative.cache_type_k;
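
A minimal sketch of what this hunk changes (assuming the unified KV cache that this patch later enables by default): the draft context is now queried from the target context instead of being derived by dividing the total context across slots.

```cpp
// Before: every slot was assumed to own an equal share of the context.
const int n_ctx_dft_old = params_base.n_ctx / params_base.n_parallel;

// After: ask the context itself; with a unified KV cache a single sequence
// may use up to the full n_ctx, so the draft context should match that.
const int n_ctx_dft_new = llama_n_ctx_seq(ctx);
```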
@@ -2495,10 +2495,16 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
-
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
+        const int n_ctx_train = llama_model_n_ctx_train(model);
+
+        int n_ctx_slot = llama_n_ctx_seq(ctx);
+        if (n_ctx_slot > n_ctx_train) {
+            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
+            n_ctx_slot = n_ctx_train;
+        }
+
        for (int i = 0; i < params_base.n_parallel; i++) {
            server_slot slot;
 
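A worked example of the cap above, with hypothetical numbers: a server started with a 262144-token context on a model trained at 32768 tokens.

```cpp
// Hypothetical values for illustration only.
int n_ctx_slot        = 262144;  // what llama_n_ctx_seq(ctx) might report
const int n_ctx_train = 32768;   // llama_model_n_ctx_train(model)

if (n_ctx_slot > n_ctx_train) {
    n_ctx_slot = n_ctx_train;    // capped to 32768: a slot should not claim
}                                // more context than the model was trained
                                 // to attend over
```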
@@ -2527,7 +2533,7 @@ struct server_context {
                 }
             }
 
-            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+            SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -2699,6 +2705,39 @@ struct server_context {
         return ret;
     }
 
+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge (LRU or longest prompt?)
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+
+                // purge slots one by one
+                break;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
 
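The TODO in `try_purge_idle_slots()` leaves the eviction policy naive: the first idle slot holding cached tokens is purged. A hedged sketch of the "longest prompt" variant the TODO mentions, reusing only fields and helpers that appear in this patch (the function name is hypothetical):

```cpp
// Hypothetical variant: purge the idle slot holding the most cached tokens.
bool try_purge_idle_slots_longest() {
    if (!params_base.kv_unified) {
        return false;
    }

    server_slot * victim = nullptr;

    for (auto & slot : slots) {
        if (slot.is_processing() || slot.prompt.n_tokens() == 0) {
            continue;
        }
        if (victim == nullptr || slot.prompt.n_tokens() > victim->prompt.n_tokens()) {
            victim = &slot;  // largest idle prompt seen so far
        }
    }

    if (victim == nullptr) {
        return false;  // nothing idle to evict
    }

    SRV_WRN("purging slot %d with %zu tokens\n", victim->id, victim->prompt.tokens.size());

    llama_memory_seq_rm(llama_get_memory(ctx), victim->id, -1, -1);
    victim->prompt.tokens.clear();

    return true;
}
```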
@@ -3635,9 +3674,10 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // next, batch any pending prompts without exceeding n_batch
-        float alora_scale = -1.0f;
+        float  alora_scale       = -1.0f;
         size_t alora_disabled_id = 0;
+
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -3914,8 +3954,11 @@ struct server_context {
 
                 // truncate any tokens that are beyond n_past for this slot
                 const llama_pos p0 = slot.prompt.tokens.pos_next();
+
+                SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
+
                 if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
-                    SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
+                    SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
                     llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                     // there is no common part left
@@ -3924,8 +3967,6 @@ struct server_context {
                     slot.prompt.tokens.clear();
                 }
 
-                SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
-
                 // check if we should process the image
                 if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                     // process the image
@@ -4126,6 +4167,8 @@ struct server_context {
                 std::string err;
 
                 if (n_batch == 1 && ret == 1) {
+                    // TODO: try to terminate only the largest active slot/sequence and continue with the rest
+                    //       need to remove the tokens from the current batch too
                     err = "Context size has been exceeded.";
                 }
 
@@ -4141,17 +4184,23 @@ struct server_context {
                 // TODO: handle ret == 2 (abort) when we start aborting
 
                 if (!err.empty()) {
-                    SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                    SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                     for (auto & slot : slots) {
-                        send_error(slot, err);
-                        slot.release();
+                        if (slot.is_processing()) {
+                            send_error(slot, err);
+                            slot.release();
+                        }
                     }
+
                     break;
                 }
             }
 
             // retry with half the batch size to try to find a free slot in the KV cache
-            n_batch /= 2;
+            if (!try_purge_idle_slots()) {
+                n_batch /= 2;
+            }
 
             SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
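Condensed, the new retry policy reads: on a KV-cache allocation failure, first try to free space by evicting an idle slot, and only halve the batch when nothing could be purged. A simplified sketch of that control flow (batch views and error handling omitted; `ret` and `n_batch` as in the loop above):

```cpp
// Simplified retry path for ret == 1 (no free KV-cache space).
while (ret == 1 && n_batch > 1) {
    if (!try_purge_idle_slots()) {
        n_batch /= 2;  // nothing idle to evict: fall back to a smaller batch
    }
    // ... re-submit the batch view and update ret ...
}
// once n_batch reaches 1 and decoding still fails, the error branch above
// reports "Context size has been exceeded." and releases processing slots
```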
@@ -4391,6 +4440,15 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // TODO: should we have a separate n_parallel parameter for the server?
+    //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+    if (params.n_parallel == 1 && params.kv_unified == false) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
+
     common_init();
 
     // struct that contains llama context and inference
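A hedged note on this new default: with a unified KV cache all slots are assumed to draw from one shared pool of `n_ctx` cells, so raising `n_parallel` to 4 should not multiply KV memory the way four split caches would.

```cpp
// What changes for a plain launch with no parallelism options set
// (assumption, based on the split vs. unified behavior in this patch):
//
//   before: 1 slot with the full n_ctx (asking for 4 slots would have
//           split the context into 4 x n_ctx/4)
//   after:  4 slots sharing one pool of n_ctx cells; a single busy slot
//           may still use the entire context on its own
```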
@@ -4944,7 +5002,7 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
                 auto n_prompt_tokens = inputs[i].size();