@@ -2413,7 +2413,7 @@ struct server_context {
 
         params_dft.devices      = params_base.speculative.devices;
         params_dft.model        = params_base.speculative.model;
-        params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+        params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? slots.front().n_ctx : params_base.speculative.n_ctx;
         params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
         params_dft.n_parallel   = 1;
         params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2501,7 +2501,7 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
+        const int32_t n_ctx_slot = params_base.kv_unified ? n_ctx : n_ctx / params_base.n_parallel;
 
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
@@ -2705,6 +2705,36 @@ struct server_context {
         return ret;
     }
 
+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
 
@@ -3640,9 +3670,10 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // next, batch any pending prompts without exceeding n_batch
-        float alora_scale = -1.0f;
+        float alora_scale = -1.0f;
         size_t alora_disabled_id = 0;
+
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -4123,6 +4154,8 @@ struct server_context {
                 std::string err;
 
                 if (n_batch == 1 && ret == 1) {
+                    // TODO: try to terminate only the largest active slot and continue
+                    //       need to remove the tokens from the current batch too
                     err = "Context size has been exceeded.";
                 }
 
@@ -4138,17 +4171,23 @@ struct server_context {
                 // TODO: handle ret == 2 (abort) when we start aborting
 
                 if (!err.empty()) {
-                    SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                    SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                     for (auto & slot : slots) {
-                        send_error(slot, err);
-                        slot.release();
+                        if (slot.is_processing()) {
+                            send_error(slot, err);
+                            slot.release();
+                        }
                     }
+
                     break;
                 }
             }
 
             // retry with half the batch size to try to find a free slot in the KV cache
-            n_batch /= 2;
+            if (!try_purge_idle_slots()) {
+                n_batch /= 2;
+            }
 
             SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
@@ -4942,7 +4981,7 @@ int main(int argc, char ** argv) {
             // Everything else, including multimodal completions.
            inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
-        const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+        const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
         tasks.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
             auto n_prompt_tokens = inputs[i].size();