server : support unified cache across slots #16736
Changes from 13 commits
@@ -112,11 +112,16 @@ llama_context::llama_context(
         }
     }

-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    if (cparams.kv_unified) {
+        cparams.n_ctx_seq = cparams.n_ctx;
+    } else {
+        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+        cparams.n_ctx     = cparams.n_ctx_seq * cparams.n_seq_max;
+    }

     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
@@ -125,14 +130,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);

-    if (n_ctx_per_seq < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }

-    if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
Comment on lines +146 to +148:

This branch should not be reached due to the capping above on line 117. But keeping it in case the capping logic gets changed in the future.
     }

     if (!hparams.vocab_only) {
@@ -453,8 +458,8 @@ uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }

-uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.n_ctx / cparams.n_seq_max;
+uint32_t llama_context::n_ctx_seq() const {
+    return cparams.n_ctx_seq;
 }

 uint32_t llama_context::n_batch() const {
@@ -2383,6 +2388,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }

+uint32_t llama_n_ctx_seq(const llama_context * ctx) {
+    return ctx->n_ctx_seq();
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
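For API consumers, the new accessor can be queried right after context creation. A minimal sketch, assuming the usual llama.h entry points; the model path "model.gguf" and the parameter values are placeholders, not something taken from this PR.

```cpp
#include <cstdio>
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // hypothetical path
    if (!model) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = 8192;
    cparams.n_seq_max  = 4;
    cparams.kv_unified = true; // unified KV cache: sequences share the full context

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        llama_model_free(model);
        return 1;
    }

    // with kv_unified = true both values should match;
    // without it, n_ctx_seq would be n_ctx / n_seq_max
    printf("n_ctx     = %u\n", llama_n_ctx(ctx));
    printf("n_ctx_seq = %u\n", llama_n_ctx_seq(ctx));

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}
```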
@@ -2407,7 +2407,7 @@ struct server_context {

             params_dft.devices      = params_base.speculative.devices;
             params_dft.model        = params_base.speculative.model;
-            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;
             params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2495,10 +2495,16 @@ struct server_context {
    }

    void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
-
        SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);

+        const int n_ctx_train = llama_model_n_ctx_train(model);
+
+        int n_ctx_slot = llama_n_ctx_seq(ctx);
+        if (n_ctx_slot > n_ctx_train) {
+            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
+            n_ctx_slot = n_ctx_train;
+        }
+
        for (int i = 0; i < params_base.n_parallel; i++) {
            server_slot slot;

@@ -2527,7 +2533,7 @@ struct server_context {
            }
        }

-        SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+        SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);

        slot.callback_on_release = [this](int) {
            queue_tasks.pop_deferred_task();
@@ -2699,6 +2705,39 @@ struct server_context {
        return ret;
    }

+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge (LRU or longest prompt?)
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+
+                // purge slots one by one
+                break;
+            }
+        }
+
+        return res;
+    }
+
    bool launch_slot_with_task(server_slot & slot, server_task && task) {
        slot.reset();

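The TODO above leaves the purge policy deliberately simple: purge the first idle, non-empty slot found, one per retry. As a rough sketch of the LRU variant it mentions (purely illustrative; `t_last_used` is a hypothetical per-slot timestamp, not a field this PR adds), slot selection could look like this:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified slot state for illustration only.
struct slot_info {
    int     id;
    bool    processing;
    size_t  n_tokens;
    int64_t t_last_used; // hypothetical timestamp of last activity
};

// Sketch of an LRU purge choice, as speculated in the TODO above:
// return the index of the idle, non-empty slot that was used least recently,
// or -1 if there is nothing to purge.
int pick_slot_to_purge(const std::vector<slot_info> & slots) {
    int best = -1;
    for (size_t i = 0; i < slots.size(); ++i) {
        const slot_info & s = slots[i];
        if (s.processing || s.n_tokens == 0) {
            continue;
        }
        if (best < 0 || s.t_last_used < slots[best].t_last_used) {
            best = (int) i;
        }
    }
    return best;
}
```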
@@ -3635,9 +3674,10 @@ struct server_context {
        int32_t n_batch = llama_n_batch(ctx);
        int32_t n_ubatch = llama_n_ubatch(ctx);

-        // next, batch any pending prompts without exceeding n_batch
-        float alora_scale = -1.0f;
+        float alora_scale        = -1.0f;
         size_t alora_disabled_id = 0;
+
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -3914,8 +3954,11 @@ struct server_context {

                    // truncate any tokens that are beyond n_past for this slot
                    const llama_pos p0 = slot.prompt.tokens.pos_next();

+                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
+
                    if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
-                        SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
+                        SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
+                        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

                        // there is no common part left
@@ -3924,8 +3967,6 @@ struct server_context {
                        slot.prompt.tokens.clear();
                    }

-                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
-
                    // check if we should process the image
                    if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                        // process the image
@@ -4126,6 +4167,8 @@ struct server_context {
                std::string err;

                if (n_batch == 1 && ret == 1) {
+                    // TODO: try to terminate only the largest active slot/sequence and continue with the rest
+                    //       need to remove the tokens from the current batch too
                    err = "Context size has been exceeded.";
                }

@@ -4141,17 +4184,23 @@ struct server_context {
                // TODO: handle ret == 2 (abort) when we start aborting

                if (!err.empty()) {
-                    SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                    SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                    for (auto & slot : slots) {
-                        send_error(slot, err);
-                        slot.release();
+                        if (slot.is_processing()) {
+                            send_error(slot, err);
+                            slot.release();
+                        }
                    }
+
                    break;
                }
            }

            // retry with half the batch size to try to find a free slot in the KV cache
-            n_batch /= 2;
+            if (!try_purge_idle_slots()) {
+                n_batch /= 2;
+            }

            SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);

Comments on this warning:

this warning should be moved inside the …

Also, maybe I forgot this from a discussion before, but currently in which case do we need to retry with a smaller batch size?

The main case for retrying with smaller batches was back when we didn't have … But generally, when …
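To summarize the control flow that the hunk above changes (a simplified sketch, not the actual server loop): when decoding fails because no KV-cache slot is available, the unified-cache path first tries to purge an idle slot and retries at the same batch size, and only falls back to halving the batch when nothing could be purged. `decode_once` and `purge_one_idle_slot` below are stand-ins for `llama_decode` and `try_purge_idle_slots`.

```cpp
#include <cstdint>

// Simplified sketch of the retry policy above (illustration only, not the real server code).
bool process_batch(int32_t n_batch,
                   bool    kv_unified,
                   int32_t (*decode_once)(int32_t),
                   bool    (*purge_one_idle_slot)()) {
    while (n_batch > 0) {
        const int32_t ret = decode_once(n_batch);
        if (ret == 0) {
            return true;  // decoded successfully
        }
        if (ret != 1) {
            return false; // fatal error or abort - not retried here
        }
        // ret == 1: no free KV cache space could be found for this batch
        if (kv_unified && purge_one_idle_slot()) {
            continue;     // an idle slot was purged - retry with the same batch size
        }
        n_batch /= 2;     // otherwise fall back to a smaller batch
    }
    return false;
}
```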
@@ -4391,6 +4440,13 @@ int main(int argc, char ** argv) {
        return 1;
    }

+    if (params.n_parallel == 1 && params.kv_unified == false) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
+
Comment on lines +4445 to +4450:

Is there a reason why this can't be default params in …?

I'll see if I can make it the default - I thought that some of the examples might not like it.

Hmm, yeah, I didn't notice that there are multiple examples all using … In this case, maybe we can use a dedicated variable for the server, like … This can be useful when auto-generating the documentation for the server args.
    common_init();

    // struct that contains llama context and inference
@@ -4944,7 +5000,7 @@ int main(int argc, char ** argv) {
                // Everything else, including multimodal completions.
                inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
            }
-            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
            tasks.reserve(inputs.size());
            for (size_t i = 0; i < inputs.size(); i++) {
                auto n_prompt_tokens = inputs[i].size();
Maybe an error could be returned here if n_ctx is not a multiple of n_seq_max, since that's likely to be a mistake.

I added a warning. The problem that I see with throwing an error is that the user might often want to use the default training context split among, for example, 3 sequences. And in the majority of cases the training context (typically a power of 2) would not be divisible by 3, resulting in an error.
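To make the divisibility concern concrete (example numbers, not from the PR): splitting a 4096-token training context across 3 sequences gives each slot 4096 / 3 = 1365 tokens, and the effective total context is silently rounded down to 1365 * 3 = 4095, which is why a warning is friendlier here than a hard error.

```cpp
#include <cstdint>
#include <cstdio>

// Example numbers only: the rounding that a non-divisible n_ctx / n_seq_max produces.
int main() {
    const uint32_t n_ctx_train = 4096;
    const uint32_t n_seq_max   = 3;

    const uint32_t n_ctx_seq = n_ctx_train / n_seq_max; // 1365
    const uint32_t n_ctx     = n_ctx_seq * n_seq_max;   // 4095 (1 token dropped)

    printf("n_ctx_seq = %u, effective n_ctx = %u\n", n_ctx_seq, n_ctx);
    return 0;
}
```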