
Commit 5b2093b

server : handle context overflow during decode (#17267)
* server : handle context overflow during decode
* server : minor refactor
1 parent 52e5d42 commit 5b2093b
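
In short: prompt_load() now returns a bool instead of cleaning up after a failed cache load, a new clear_slot() helper removes both the sequence's KV cells and the slot's cached tokens in one place, and that helper is called wherever the slot's state can no longer be trusted (failed prompt-cache load, slot erase, failed truncation, and now a decode error). With per-slot clearing in place, the global clean_kv_cache / kv_cache_clear() path is removed. Below is a minimal, self-contained sketch of that pattern; the types and helpers are simplified stand-ins for illustration, not the actual server code:

    // Simplified stand-ins, not the real llama.cpp server types.
    #include <cstdio>
    #include <vector>

    struct slot_sketch {
        int              id = 0;
        std::vector<int> prompt_tokens; // stand-in for slot.prompt.tokens
    };

    // stand-in for llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1)
    static void memory_seq_rm(int seq_id) {
        std::printf("removing all KV cells of sequence %d\n", seq_id);
    }

    // mirrors the new clear_slot(): KV cells and cached tokens are always dropped together
    static void clear_slot(slot_sketch & slot) {
        std::printf("clearing slot %d with %zu tokens\n", slot.id, slot.prompt_tokens.size());
        memory_seq_rm(slot.id);
        slot.prompt_tokens.clear();
    }

    // stand-in for prompt_load(): it only reports failure, the caller decides how to recover
    static bool prompt_load(slot_sketch & slot, const std::vector<int> & tokens) {
        if (tokens.empty()) {
            return false; // pretend the prompt-cache lookup failed
        }
        slot.prompt_tokens = tokens;
        return true;
    }

    int main() {
        slot_sketch slot;
        slot.id = 3;

        // same shape as the new call site: on failure, reset the slot entirely
        if (!prompt_load(slot, {})) {
            clear_slot(slot);
        }
        return 0;
    }

Centralizing the cleanup in one helper keeps the KV cells and the cached tokens from drifting out of sync across the different failure paths.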

1 file changed: +30 -29 lines changed


tools/server/server.cpp

Lines changed: 30 additions & 29 deletions
@@ -1686,14 +1686,13 @@ struct server_slot {
         llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
     }

-    void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
+    bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
         bool res = prompt_cache.load(prompt, tokens, ctx, id);
         if (!res) {
             SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
-
-            llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
-            prompt.tokens.clear();
         }
+
+        return res;
     }

     std::vector<common_adapter_lora_info> lora;
@@ -2339,7 +2338,6 @@ struct server_context {

     llama_batch batch {};

-    bool clean_kv_cache = true;
     bool add_bos_token = true;

     int32_t n_ctx; // total context for all clients / slots
@@ -2702,7 +2700,10 @@ struct server_context {
                 const int64_t t_start = ggml_time_us();

                 ret->prompt_save(*prompt_cache);
-                ret->prompt_load(*prompt_cache, task.tokens);
+
+                if (!ret->prompt_load(*prompt_cache, task.tokens)) {
+                    clear_slot(*ret);
+                }

                 prompt_cache->update();

@@ -2713,12 +2714,21 @@ struct server_context {
         return ret;
     }

-    // return true if at least one slot has been purged
+    void clear_slot(server_slot & slot) const {
+        GGML_ASSERT(!slot.is_processing());
+
+        SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
+
+        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+        slot.prompt.tokens.clear();
+    }
+
+    // return true if at least one slot has been cleared
     // TODO: improve logic
-    //       - smarter decision which slot to purge (LRU or longest prompt?)
+    //       - smarter decision which slot to clear (LRU or longest prompt?)
     //       - move slot to level 2 cache instead of removing?
     //       - instead of purging, try to store and resume later?
-    bool try_purge_idle_slots() {
+    bool try_clear_idle_slots() {
         bool res = false;

         if (!params_base.kv_unified) {
@@ -2733,12 +2743,11 @@ struct server_context {
             if (slot.prompt.n_tokens() > 0) {
                 SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());

-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-                slot.prompt.tokens.clear();
+                clear_slot(slot);

                 res = true;

-                // purge slots one by one
+                // clear slots one by one
                 break;
             }
         }
@@ -2848,14 +2857,6 @@ struct server_context {
         return true;
     }

-    void kv_cache_clear() {
-        SRV_DBG("%s", "clearing KV cache\n");
-
-        // clear the entire KV cache
-        llama_memory_clear(llama_get_memory(ctx), true);
-        clean_kv_cache = false;
-    }
-
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = result.text_to_send;
@@ -3443,8 +3444,8 @@ struct server_context {

                 // Erase token cache
                 const size_t n_erased = slot->prompt.tokens.size();
-                llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
-                slot->prompt.tokens.clear();
+
+                clear_slot(*slot);

                 auto res = std::make_unique<server_task_result_slot_erase>();
                 res->id = task.id;
@@ -3477,9 +3478,6 @@ struct server_context {

         if (all_idle) {
             SRV_INF("%s", "all slots are idle\n");
-            if (clean_kv_cache) {
-                kv_cache_clear();
-            }

             return;
         }
@@ -3873,12 +3871,11 @@ struct server_context {

                 if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
                     SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
-                    llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+
+                    clear_slot(slot);

                     // there is no common part left
                     slot.n_prompt_tokens_cache = 0;
-
-                    slot.prompt.tokens.clear();
                 }

                 // check if we should process the image
@@ -4108,6 +4105,10 @@ struct server_context {
                 if (slot.is_processing()) {
                     send_error(slot, err);
                     slot.release();
+
+                    // note: it's complicated to keep track of how much of the current batch has been
+                    //       processed before the error occurred, so we simply clear the entire context
+                    clear_slot(slot);
                 }
             }

@@ -4116,7 +4117,7 @@ struct server_context {
             }

             // retry with half the batch size to try to find a free slot in the KV cache
-            if (!try_purge_idle_slots()) {
+            if (!try_clear_idle_slots()) {
                 n_batch /= 2;
             }

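The last hunk only renames the helper, but the retry policy around it is worth spelling out: when llama_decode cannot find room in the KV cache, the server first tries to clear an idle slot and only halves the batch size if no slot could be cleared. A rough, self-contained sketch of that loop, with stand-in functions in place of the real decode and slot handling:

    // Rough sketch of the retry policy; decode() and try_clear_idle_slots() are stand-ins.
    #include <cstdio>

    static bool decode(int n_tokens) {
        return n_tokens <= 64; // pretend the KV cache only has room for 64 tokens
    }

    static bool try_clear_idle_slots() {
        return false; // pretend every slot is busy, so nothing can be cleared
    }

    int main() {
        int n_batch = 512;

        while (n_batch > 0) {
            if (decode(n_batch)) {
                std::printf("decoded a batch of %d tokens\n", n_batch);
                break;
            }

            // retry with half the batch size to try to find a free slot in the KV cache
            if (!try_clear_idle_slots()) {
                n_batch /= 2;
            }
        }
        return 0;
    }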