Skip to content

Commit 42e9fe8

Browse files
committed
server : support unified context across slots
1 parent 10fcc41 commit 42e9fe8

File tree

1 file changed

+48
-9
lines changed

1 file changed

+48
-9
lines changed

tools/server/server.cpp

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2413,7 +2413,7 @@ struct server_context {
24132413

24142414
params_dft.devices = params_base.speculative.devices;
24152415
params_dft.model = params_base.speculative.model;
2416-
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
2416+
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? slots.front().n_ctx : params_base.speculative.n_ctx;
24172417
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
24182418
params_dft.n_parallel = 1;
24192419
params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2501,7 +2501,7 @@ struct server_context {
25012501
}
25022502

25032503
void init() {
2504-
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
2504+
const int32_t n_ctx_slot = params_base.kv_unified ? n_ctx : n_ctx / params_base.n_parallel;
25052505

25062506
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
25072507

@@ -2705,6 +2705,36 @@ struct server_context {
27052705
return ret;
27062706
}
27072707

2708+
// return true if at least one slot has been purged
2709+
// TODO: improve logic
2710+
// - smarter decision which slot to purge
2711+
// - move slot to level 2 cache instead of removing?
2712+
// - instead of purging, try to store and resume later?
2713+
bool try_purge_idle_slots() {
2714+
bool res = false;
2715+
2716+
if (!params_base.kv_unified) {
2717+
return res;
2718+
}
2719+
2720+
for (auto & slot : slots) {
2721+
if (slot.is_processing()) {
2722+
continue;
2723+
}
2724+
2725+
if (slot.prompt.n_tokens() > 0) {
2726+
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
2727+
2728+
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
2729+
slot.prompt.tokens.clear();
2730+
2731+
res = true;
2732+
}
2733+
}
2734+
2735+
return res;
2736+
}
2737+
27082738
bool launch_slot_with_task(server_slot & slot, server_task && task) {
27092739
slot.reset();
27102740

@@ -3640,9 +3670,10 @@ struct server_context {
36403670
int32_t n_batch = llama_n_batch(ctx);
36413671
int32_t n_ubatch = llama_n_ubatch(ctx);
36423672

3643-
// next, batch any pending prompts without exceeding n_batch
3644-
float alora_scale = -1.0f;
3673+
float alora_scale = -1.0f;
36453674
size_t alora_disabled_id = 0;
3675+
3676+
// next, batch any pending prompts without exceeding n_batch
36463677
if (params_base.cont_batching || batch.n_tokens == 0) {
36473678
for (auto & slot : slots) {
36483679
// check if we can batch this slot with the previous one
@@ -4123,6 +4154,8 @@ struct server_context {
41234154
std::string err;
41244155

41254156
if (n_batch == 1 && ret == 1) {
4157+
// TODO: try to terminate only the largest active slot and continue
4158+
// need to remove the tokens from the current batch too
41264159
err = "Context size has been exceeded.";
41274160
}
41284161

@@ -4138,17 +4171,23 @@ struct server_context {
41384171
// TODO: handle ret == 2 (abort) when we start aborting
41394172

41404173
if (!err.empty()) {
4141-
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
4174+
SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
4175+
41424176
for (auto & slot : slots) {
4143-
send_error(slot, err);
4144-
slot.release();
4177+
if (slot.is_processing()) {
4178+
send_error(slot, err);
4179+
slot.release();
4180+
}
41454181
}
4182+
41464183
break;
41474184
}
41484185
}
41494186

41504187
// retry with half the batch size to try to find a free slot in the KV cache
4151-
n_batch /= 2;
4188+
if (!try_purge_idle_slots()) {
4189+
n_batch /= 2;
4190+
}
41524191

41534192
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
41544193

@@ -4942,7 +4981,7 @@ int main(int argc, char ** argv) {
49424981
// Everything else, including multimodal completions.
49434982
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
49444983
}
4945-
const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
4984+
const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
49464985
tasks.reserve(inputs.size());
49474986
for (size_t i = 0; i < inputs.size(); i++) {
49484987
auto n_prompt_tokens = inputs[i].size();

0 commit comments

Comments
 (0)