
Commit cfba346

generalize swa_checkpoint to ctx_checkpoint
this extends `llama-server`'s SWA checkpointing logic to include hybrid/recurrent models such as Jamba and Granite
1 parent 257d492 commit cfba346
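
The core of the change is the condition for when a slot saves a checkpoint, which now covers recurrent/hybrid memory in addition to SWA. A condensed sketch of that gating logic, pulled out of the diff below (the free function `should_save_checkpoint` and its standalone parameters are illustrative only; in `server.cpp` the values come from `params_base`):

    #include "llama.h"

    // a context checkpoint is needed when the memory state cannot simply be rolled back:
    //  - recurrent or hybrid architectures (e.g. Jamba, Granite)
    //  - SWA models, unless --swa-full is used
    static bool should_save_checkpoint(const llama_model * model, bool swa_full, int32_t n_ctx_checkpoints) {
        const bool needs_checkpoint =
            llama_model_is_recurrent(model) ||
            llama_model_is_hybrid(model)    ||
            (llama_model_n_swa(model) > 0 && !swa_full);

        // checkpointing can be disabled entirely by setting the checkpoint limit to 0
        return needs_checkpoint && n_ctx_checkpoints > 0;
    }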

File tree

1 file changed: +45 −47 lines changed

tools/server/server.cpp

Lines changed: 45 additions & 47 deletions
@@ -764,7 +764,7 @@ struct completion_token_output {
     }
 };
 
-struct swa_checkpoint {
+struct ctx_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
 
@@ -1460,7 +1460,7 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
-    std::vector<swa_checkpoint> swa_checkpoints;
+    std::vector<ctx_checkpoint> ctx_checkpoints;
 
     bool has_next_token = true;
     bool has_new_line = false;
@@ -3555,38 +3555,38 @@ struct server_context {
         if (pos_min > pos_min_thold) {
             SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
 
-            // search for a SWA checkpoint
+            // search for a context checkpoint
             const auto it = std::find_if(
-                    slot.swa_checkpoints.rbegin(),
-                    slot.swa_checkpoints.rend(),
+                    slot.ctx_checkpoints.rbegin(),
+                    slot.ctx_checkpoints.rend(),
                     [&](const auto & cur) {
                         return cur.pos_min <= pos_min_thold;
                     }
            );
 
-            bool do_reset = it == slot.swa_checkpoints.rend();
+            bool do_reset = it == slot.ctx_checkpoints.rend();
+            SLT_DBG(slot, "do_reset = %s\n", do_reset ? "true" : "false");
 
             if (!do_reset) {
-                // restore the checkpoint
-                const size_t swa_size = it->data.size();
-                const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                // restore the context checkpoint
+                const size_t ctx_checkpoint_size = it->data.size();
+                const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
 
-                if (n != swa_size) {
-                    SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                if (n != ctx_checkpoint_size) {
+                    SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                     do_reset = true;
+                    SLT_DBG(slot, "do_reset set to true after failing to restore a checkpoint\n");
                 } else {
                     slot.n_past = std::min(slot.n_past, it->pos_max);
-
-                    SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                    SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                 }
             }
 
             if (do_reset) {
                 SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                         "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
                 slot.n_past = 0;
-                slot.swa_checkpoints.clear();
+                slot.ctx_checkpoints.clear();
             }
         }
     }
@@ -3595,21 +3595,20 @@ struct server_context {
             const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
             // erase any checkpoints with pos_min > pos_min_thold
-            for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
-                const auto & cur = slot.swa_checkpoints[i];
+            for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                const auto & cur = slot.ctx_checkpoints[i];
                 if (cur.pos_min > pos_min_thold) {
-                    slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
-                    SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                    SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                    slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                 }
             }
         }
     }
 
     if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-        SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+        SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
         slot.n_past--;
+        SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
     }
 
     slot.n_prompt_tokens_cache = slot.n_past;
@@ -3623,17 +3622,18 @@ struct server_context {
         }
     }
 
-    // keep only the common part
+    // truncate any tokens that are beyond n_past for this slot
     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-        // could not partially delete (likely using a non-Transformer model)
+        SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
         // there is no common part left
         slot.n_past = 0;
         slot.n_prompt_tokens_cache = 0;
+        SLT_DBG(slot, "truncated all tokens from this slot\n");
     }
 
-    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+    SLT_INF(slot, "n_past = %d\n", slot.n_past);
 
     // remove the non-common part from the cache
     slot.cache_tokens.keep_first(slot.n_past);
@@ -3854,37 +3854,35 @@ struct server_context {
         // prompt evaluated for next-token prediction
         slot.state = SLOT_STATE_GENERATING;
 
-        // make a checkpoint with the SWA memory
-        // checkpoints are needed only if we are not using "--swa-full"
-        if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
-            if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
-                {
-                    const auto & cur = slot.swa_checkpoints.back();
-
-                    SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
-                            cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                }
-
-                slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+        // make a checkpoint of the parts of memory that cannot be rolled back.
+        // checkpoints are needed only if:
+        //  - the model uses SWA and we are not using `swa_full`
+        //  - the model architecture is marked as recurrent or hybrid
+        bool do_checkpoint = (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+                             (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+        if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+            if (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                // make room for the new checkpoint, if needed
+                const auto & cur = slot.ctx_checkpoints.front();
+                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
             }
 
-            const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
 
-            auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+            auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                 /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                 /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                /*.data = */ std::vector<uint8_t>(swa_size),
+                /*.data = */ std::vector<uint8_t>(checkpoint_size),
             });
 
-            llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-            float size_total = 0.0f;
-            for (const auto & checkpoint : slot.swa_checkpoints) {
-                size_total += (float) checkpoint.data.size() / 1024 / 1024;
-            }
+            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
 
-            SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
-                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+            SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                    (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
        }
     } else if (slot.state != SLOT_STATE_GENERATING) {
         continue; // continue loop of slots
