@@ -2413,7 +2413,7 @@ struct server_context {
24132413
24142414 params_dft.devices = params_base.speculative .devices ;
24152415 params_dft.model = params_base.speculative .model ;
2416- params_dft.n_ctx = params_base.speculative .n_ctx == 0 ? params_base. n_ctx / params_base. n_parallel : params_base.speculative .n_ctx ;
2416+ params_dft.n_ctx = params_base.speculative .n_ctx == 0 ? slots. front (). n_ctx : params_base.speculative .n_ctx ;
24172417 params_dft.n_gpu_layers = params_base.speculative .n_gpu_layers ;
24182418 params_dft.n_parallel = 1 ;
24192419 params_dft.cache_type_k = params_base.speculative .cache_type_k ;
@@ -2501,7 +2501,7 @@ struct server_context {
25012501 }
25022502
25032503 void init () {
2504- const int32_t n_ctx_slot = n_ctx / params_base.n_parallel ;
2504+ const int32_t n_ctx_slot = params_base. kv_unified ? n_ctx : n_ctx / params_base.n_parallel ;
25052505
25062506 SRV_INF (" initializing slots, n_slots = %d\n " , params_base.n_parallel );
25072507
@@ -2705,6 +2705,36 @@ struct server_context {
27052705 return ret;
27062706 }
27072707
2708+ // return true if at least one slot has been purged
2709+ // TODO: improve logic
2710+ // - smarter decision which slot to purge
2711+ // - move slot to level 2 cache instead of removing?
2712+ // - instead of purging, try to store and resume later?
2713+ bool try_purge_idle_slots () {
2714+ bool res = false ;
2715+
2716+ if (!params_base.kv_unified ) {
2717+ return res;
2718+ }
2719+
2720+ for (auto & slot : slots) {
2721+ if (slot.is_processing ()) {
2722+ continue ;
2723+ }
2724+
2725+ if (slot.prompt .n_tokens () > 0 ) {
2726+ SRV_WRN (" purging slot %d with %zu tokens\n " , slot.id , slot.prompt .tokens .size ());
2727+
2728+ llama_memory_seq_rm (llama_get_memory (ctx), slot.id , -1 , -1 );
2729+ slot.prompt .tokens .clear ();
2730+
2731+ res = true ;
2732+ }
2733+ }
2734+
2735+ return res;
2736+ }
2737+
27082738 bool launch_slot_with_task (server_slot & slot, server_task && task) {
27092739 slot.reset ();
27102740
@@ -3661,9 +3691,10 @@ struct server_context {
36613691 int32_t n_batch = llama_n_batch (ctx);
36623692 int32_t n_ubatch = llama_n_ubatch (ctx);
36633693
3664- // next, batch any pending prompts without exceeding n_batch
3665- float alora_scale = -1 .0f ;
3694+ float alora_scale = -1 .0f ;
36663695 size_t alora_disabled_id = 0 ;
3696+
3697+ // next, batch any pending prompts without exceeding n_batch
36673698 if (params_base.cont_batching || batch.n_tokens == 0 ) {
36683699 for (auto & slot : slots) {
36693700 // check if we can batch this slot with the previous one
@@ -4144,6 +4175,8 @@ struct server_context {
41444175 std::string err;
41454176
41464177 if (n_batch == 1 && ret == 1 ) {
4178+ // TODO: try to terminate only the largest active slot and continue
4179+ // need to remove the tokens from the current batch too
41474180 err = " Context size has been exceeded." ;
41484181 }
41494182
@@ -4159,17 +4192,23 @@ struct server_context {
41594192 // TODO: handle ret == 2 (abort) when we start aborting
41604193
41614194 if (!err.empty ()) {
4162- SRV_ERR (" %s, i = %d, n_batch = %d, ret = %d\n " , err.c_str (), i, n_batch, ret);
4195+ SRV_ERR (" %s i = %d, n_batch = %d, ret = %d\n " , err.c_str (), i, n_batch, ret);
4196+
41634197 for (auto & slot : slots) {
4164- send_error (slot, err);
4165- slot.release ();
4198+ if (slot.is_processing ()) {
4199+ send_error (slot, err);
4200+ slot.release ();
4201+ }
41664202 }
4203+
41674204 break ;
41684205 }
41694206 }
41704207
41714208 // retry with half the batch size to try to find a free slot in the KV cache
4172- n_batch /= 2 ;
4209+ if (!try_purge_idle_slots ()) {
4210+ n_batch /= 2 ;
4211+ }
41734212
41744213 SRV_WRN (" failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n " , i, n_batch, ret);
41754214
@@ -4963,7 +5002,7 @@ int main(int argc, char ** argv) {
49635002 // Everything else, including multimodal completions.
49645003 inputs = tokenize_input_prompts (ctx_server.vocab , ctx_server.mctx , prompt, true , true );
49655004 }
4966- const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server. params_base . n_parallel ;
5005+ const size_t n_ctx_slot = ctx_server.slots . front (). n_ctx ;
49675006 tasks.reserve (inputs.size ());
49685007 for (size_t i = 0 ; i < inputs.size (); i++) {
49695008 auto n_prompt_tokens = inputs[i].size ();
0 commit comments