server : support unified cache across slots #16736
Changes from all commits: 57ece5b, a42fb77, 492f628, 8222e9c, 2179175, f0f105f, e7b7cbf, 290f6a9, 23323cd, f2cca02, ff68436, c08d0d1, 356dc08, 56fceee
```diff
@@ -2407,7 +2407,7 @@ struct server_context {
             params_dft.devices = params_base.speculative.devices;
             params_dft.model = params_base.speculative.model;
-            params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+            params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel = 1;
             params_dft.cache_type_k = params_base.speculative.cache_type_k;
```
```diff
@@ -2495,10 +2495,16 @@ struct server_context {
     }

     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
-
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);

+        const int n_ctx_train = llama_model_n_ctx_train(model);
+
+        int n_ctx_slot = llama_n_ctx_seq(ctx);
+        if (n_ctx_slot > n_ctx_train) {
+            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
+            n_ctx_slot = n_ctx_train;
+        }
+
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
```
```diff
@@ -2527,7 +2533,7 @@ struct server_context {
                 }
             }

-            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+            SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);

            slot.callback_on_release = [this](int) {
                queue_tasks.pop_deferred_task();
```
```diff
@@ -2699,6 +2705,39 @@ struct server_context {
         return ret;
     }

+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge (LRU or longest prompt?)
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+
+                // purge slots one by one
+                break;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
```
```diff
@@ -3635,9 +3674,10 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);

-        // next, batch any pending prompts without exceeding n_batch
-        float alora_scale = -1.0f;
+        float alora_scale = -1.0f;
         size_t alora_disabled_id = 0;

+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
```
```diff
@@ -3914,8 +3954,11 @@ struct server_context {
                 // truncate any tokens that are beyond n_past for this slot
                 const llama_pos p0 = slot.prompt.tokens.pos_next();

+                SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
+
                 if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
-                    SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
+                    SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
                     llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

                     // there is no common part left
```
```diff
@@ -3924,8 +3967,6 @@ struct server_context {
                     slot.prompt.tokens.clear();
                 }

-                SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
-
                 // check if we should process the image
                 if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                     // process the image
```
```diff
@@ -4126,6 +4167,8 @@ struct server_context {
                 std::string err;

                 if (n_batch == 1 && ret == 1) {
+                    // TODO: try to terminate only the largest active slot/sequence and continue with the rest
+                    //       need to remove the tokens from the current batch too
                     err = "Context size has been exceeded.";
                 }
```
```diff
@@ -4141,17 +4184,23 @@ struct server_context {
                 // TODO: handle ret == 2 (abort) when we start aborting

                 if (!err.empty()) {
-                    SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                    SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);

                     for (auto & slot : slots) {
-                        send_error(slot, err);
-                        slot.release();
+                        if (slot.is_processing()) {
+                            send_error(slot, err);
+                            slot.release();
+                        }
                     }

                     break;
                 }
             }

             // retry with half the batch size to try to find a free slot in the KV cache
-            n_batch /= 2;
+            if (!try_purge_idle_slots()) {
+                n_batch /= 2;
+            }

             SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
```
```diff
@@ -4391,6 +4440,15 @@ int main(int argc, char ** argv) {
         return 1;
     }

+    // TODO: should we have a separate n_parallel parameter for the server?
+    //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+    if (params.n_parallel == 1 && params.kv_unified == false) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
+
     common_init();

     // struct that contains llama context and inference
```
Comment on lines +4445 to +4450:

Is there a reason why this can't be default params in [...]?

I'll see if I can make it the default - I thought that some of the examples might not like it.

Hmm yeah, I didn't notice that there are multiple examples all using [...]. In this case, maybe we can use a dedicated variable for the server, like [...]. This can be useful when auto-generating the documentation for server args.
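A rough sketch of the "dedicated server variable" idea from the thread above; the struct and field names below are assumptions made up for illustration, not actual llama.cpp parameters:

```cpp
// Hypothetical sketch only: keep server-specific defaults separate from the
// shared CLI defaults, so the other examples are unaffected while the arg
// parser can still report (and auto-document) the server values.
struct server_defaults {
    int  n_parallel = 4;    // number of slots the server starts with
    bool kv_unified = true; // one KV cache buffer shared across all slots
};
```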
```diff
@@ -4944,7 +5002,7 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
                 auto n_prompt_tokens = inputs[i].size();
```
this warning should be moved inside the `if` condition above, right?
Also, maybe I forgot this from a discussion before, but currently in which case do we need to retry with a smaller batch size?
The main case for retrying with smaller batches was back when we didn't have `ggml_set_rows` and we always had to search for a contiguous set of cells (KV slots) inside the cache buffer to place the input batch. Now with `ggml_set_rows` this is no longer needed and, technically, retrying with a smaller batch size has almost no purpose except in some rare cases. But generally, when `llama_decode` returns 1, you should retry with a smaller batch.
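For context, a minimal sketch of that retry pattern outside the server, assuming a standalone `llama_context` and a plain token vector (the helper name `decode_with_retry` is made up for illustration and is not part of this PR):

```cpp
// Minimal sketch, not the server's actual loop: feed tokens in chunks and,
// whenever llama_decode() returns 1 (no free KV cache slot for this batch),
// retry the same chunk with a progressively smaller batch size.
#include <algorithm>
#include <vector>

#include "llama.h"

static bool decode_with_retry(llama_context * ctx, std::vector<llama_token> & tokens) {
    int32_t n_batch = llama_n_batch(ctx);

    for (int32_t i = 0; i < (int32_t) tokens.size(); ) {
        const int32_t n_left = (int32_t) tokens.size() - i;
        const int32_t n_eval = std::min(n_batch, n_left);

        llama_batch batch = llama_batch_get_one(tokens.data() + i, n_eval);

        const int32_t ret = llama_decode(ctx, batch);
        if (ret == 0) {
            i += n_eval;      // chunk accepted, move on to the next one
        } else if (ret == 1 && n_batch > 1) {
            n_batch /= 2;     // no KV slot found, retry this chunk with a smaller batch
        } else {
            return false;     // hard error, or already down to n_batch == 1
        }
    }

    return true;
}
```

In the server itself, the `@@ -4141,17 +4184,23` hunk above adds one more option before shrinking: `try_purge_idle_slots()` is attempted first, and `n_batch` is only halved when no idle slot could be freed, since with a unified KV cache clearing an idle sequence may already make enough room for the batch.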