Commit 96db966

server : add SWA checkpoints
ggml-ci
1 parent c24f4e2 commit 96db966

2 files changed: +88 -4 lines


src/llama-kv-cache-unified.cpp

Lines changed: 24 additions & 0 deletions
@@ -1957,6 +1957,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
+        if (!hparams.is_swa(il)) {
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         auto * k = layer.k_stream[cr.strm];
@@ -1981,6 +1985,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
+        if (!hparams.is_swa(il)) {
+            continue;
+        }
+
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         auto * v = layer.v_stream[cr.strm];
@@ -2007,6 +2015,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
+        if (!hparams.is_swa(il)) {
+            continue;
+        }
+
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         auto * v = layer.v_stream[cr.strm];
@@ -2162,6 +2174,10 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
+        if (!hparams.is_swa(il)) {
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         auto * k = layer.k_stream[strm];
@@ -2194,6 +2210,10 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
+        if (!hparams.is_swa(il)) {
+            continue;
+        }
+
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         auto * v = layer.v_stream[strm];
@@ -2226,6 +2246,10 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
+        if (!hparams.is_swa(il)) {
+            continue;
+        }
+
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         auto * v = layer.v_stream[strm];
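
In llama-kv-cache-unified.cpp, every per-layer loop of state_write_data and state_read_data gains the same guard, so layers without sliding-window attention are neither written to the serialized sequence state nor expected when reading it back, which presumably keeps the per-slot checkpoints small. The point of applying the identical filter on both sides is that writer and reader must agree on which layers appear in the blob; the self-contained sketch below illustrates that symmetry with stand-in names (layer_meta, is_swa, write_state, read_state are not llama.cpp API).

// Minimal illustration of the writer/reader symmetry behind the added guard:
// both sides filter the layer list with the same predicate, so the reader
// never tries to consume data for a layer the writer skipped.
#include <cstdint>
#include <cstdio>
#include <vector>

struct layer_meta {
    uint32_t il;     // layer index
    bool     is_swa; // true if the layer uses sliding-window attention
};

// append only the SWA layers' payload to the blob
static void write_state(const std::vector<layer_meta> & layers, std::vector<uint32_t> & blob) {
    for (const auto & layer : layers) {
        if (!layer.is_swa) {
            continue; // same guard as in the diff
        }
        blob.push_back(layer.il); // placeholder for the layer's K/V data
    }
}

// consume the blob with the identical filter, in the identical order
static bool read_state(const std::vector<layer_meta> & layers, const std::vector<uint32_t> & blob) {
    size_t idx = 0;
    for (const auto & layer : layers) {
        if (!layer.is_swa) {
            continue;
        }
        if (idx >= blob.size() || blob[idx++] != layer.il) {
            return false; // layer set mismatch between writer and reader
        }
    }
    return idx == blob.size();
}

int main() {
    const std::vector<layer_meta> layers = {{0, true}, {1, false}, {2, true}};
    std::vector<uint32_t> blob;
    write_state(layers, blob);
    std::printf("round-trip ok: %s\n", read_state(layers, blob) ? "yes" : "no");
}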

tools/server/server.cpp

Lines changed: 64 additions & 4 deletions
@@ -692,6 +692,13 @@ struct completion_token_output {
     }
 };
 
+struct swa_checkpoint {
+    std::vector<uint8_t> data;
+
+    llama_pos pos_min;
+    llama_pos pos_max;
+};
+
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;

@@ -1336,6 +1343,8 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
+    std::vector<swa_checkpoint> swa_checkpoints;
+
     bool has_next_token = true;
     bool has_new_line = false;
     bool truncated = false;
@@ -3300,10 +3309,42 @@ struct server_context {
 
             const auto n_swa = llama_model_n_swa(model);
             if (pos_min > std::max(0, slot.n_past - n_swa)) {
-                SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
-                SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
-                        "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                slot.n_past = 0;
+                // search for a SWA checkpoint
+                int ic = -1;
+                int np = std::numeric_limits<int>::max();
+                for (int i = 0; i < (int) slot.swa_checkpoints.size(); i++) {
+                    const auto & cur = slot.swa_checkpoints[i];
+                    if (cur.pos_min <= std::max(0, slot.n_past - n_swa)) {
+                        const int p = std::max(0, slot.n_past - cur.pos_max);
+
+                        if (p < np) {
+                            ic = i;
+                            np = p;
+                        }
+                    }
+                }
+
+                if (ic == -1) {
+                    SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
+                    SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                            "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+                    slot.n_past = 0;
+
+                    slot.swa_checkpoints.clear();
+                } else {
+                    // erase all checkpoints after the one we are using
+                    slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + ic + 1, slot.swa_checkpoints.end());
+
+                    // restore the checkpoint
+                    const auto & cur = slot.swa_checkpoints[ic];
+
+                    const size_t swa_size = cur.data.size();
+                    llama_state_seq_set_data(ctx, cur.data.data(), swa_size, slot.id);
+
+                    slot.n_past = std::min(slot.n_past, cur.pos_max);
+
+                    SLT_WRN(slot, "prompt swa checkpoint restored, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
+                }
             }
         }
     }
@@ -3517,6 +3558,25 @@ struct server_context {
 
             // prompt evaluated for next-token prediction
             slot.state = SLOT_STATE_GENERATING;
+
+            // make a checkpoint
+            if (llama_model_n_swa(model) > 0) {
+                if (slot.swa_checkpoints.size() > 8) {
+                    slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                }
+
+                auto & cur = slot.swa_checkpoints.emplace_back();
+
+                cur.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                cur.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                const size_t swa_size = llama_state_seq_get_size(ctx, slot.id);
+                cur.data.resize(swa_size);
+
+                llama_state_seq_get_data(ctx, cur.data.data(), swa_size, slot.id);
+
+                SLT_WRN(slot, "prompt swa checkpoint, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
+            }
         } else if (slot.state != SLOT_STATE_GENERATING) {
             continue; // continue loop of slots
         }
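
On the creation side, as soon as the prompt finishes evaluating and the slot switches to SLOT_STATE_GENERATING, the server records the sequence's pos_min/pos_max via the memory API and serializes the sequence state into a new checkpoint, dropping the oldest entry once the per-slot list grows past 8. The helper pair below sketches the underlying snapshot/restore round-trip with the same llama_state_seq_* calls the diff uses; snapshot_seq and restore_seq are illustrative names, and error handling plus the server's checkpoint bookkeeping are omitted.

// Minimal sketch of the per-sequence snapshot/restore round-trip the server
// builds its SWA checkpoints on (same llama_state_seq_* calls as in the diff;
// context/model setup is assumed to happen elsewhere).
#include <cstdint>
#include <vector>

#include "llama.h"

static std::vector<uint8_t> snapshot_seq(llama_context * ctx, llama_seq_id seq_id) {
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, seq_id));
    llama_state_seq_get_data(ctx, buf.data(), buf.size(), seq_id);
    return buf;
}

static void restore_seq(llama_context * ctx, const std::vector<uint8_t> & buf, llama_seq_id seq_id) {
    llama_state_seq_set_data(ctx, buf.data(), buf.size(), seq_id);
}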
