Commit 7fdd16b

server : improve context checkpoint logic (ggml-org#16440)
1 parent 74b8fc1 commit 7fdd16b

File tree

2 files changed: +56 -35 lines changed


src/llama-memory-recurrent.cpp

Lines changed: 4 additions & 1 deletion
@@ -861,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
 bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence
-
         seq_rm(dest_seq_id, -1, -1);
 
+        if (cell_count == 0) {
+            return true;
+        }
+
         llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
         llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

tools/server/server.cpp

Lines changed: 52 additions & 34 deletions
@@ -3676,6 +3676,20 @@ struct server_context {
             alora_disabled_id = enabled_loras[0];
         }
 
+        bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+        // make a checkpoint of the parts of the memory that cannot be rolled back.
+        // checkpoints are created only if:
+        // - the model uses SWA and we are not using `swa_full`
+        // - the model architecture is marked as recurrent or hybrid
+        //
+        // TODO: try to make this conditional on the context or the memory module, instead of the model type
+        do_checkpoint = do_checkpoint && (
+            llama_model_is_recurrent(model) ||
+            llama_model_is_hybrid(model) ||
+            (llama_model_n_swa(model) > 0 && !params_base.swa_full)
+        );
+
         // add prompt tokens for processing in the current batch
         while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
             // get next token to process
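The gating introduced in this hunk reduces to a single predicate: checkpointing must be enabled, and the model must be recurrent, hybrid, or use SWA with `swa_full` disabled. A minimal standalone sketch, with hypothetical `checkpoint_params` and `model_traits` structs standing in for the real `params_base` fields and `llama_model_*` queries:

    #include <cstdint>

    // Hypothetical stand-ins for the params_base fields and llama_model_* queries used above.
    struct checkpoint_params {
        int  n_ctx_checkpoints; // params_base.n_ctx_checkpoints
        bool swa_full;          // params_base.swa_full
    };

    struct model_traits {
        bool    is_recurrent;   // llama_model_is_recurrent(model)
        bool    is_hybrid;      // llama_model_is_hybrid(model)
        int32_t n_swa;          // llama_model_n_swa(model)
    };

    // Mirrors the do_checkpoint condition above: checkpoints are only considered when
    // checkpointing is enabled and the model is recurrent, hybrid, or uses SWA without swa_full.
    static bool checkpoint_eligible(const checkpoint_params & params, const model_traits & model) {
        if (params.n_ctx_checkpoints <= 0) {
            return false;
        }
        return model.is_recurrent || model.is_hybrid || (model.n_swa > 0 && !params.swa_full);
    }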
@@ -3700,6 +3714,11 @@
 
             slot.n_prompt_tokens_processed++;
             slot.n_past++;
+
+            // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
+            if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+                break;
+            }
         }
 
         // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
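The early break above means that, once checkpointing is in play, the final 64 prompt tokens are deliberately left for a later batch, so the checkpoint taken at "prompt done" captures the state just before them. A minimal sketch of the resulting batch split, using hypothetical prompt and batch sizes (only the splitting rule mirrors the diff):

    #include <cstdio>

    int main() {
        const int  n_prompt_tokens = 300; // hypothetical prompt length
        const int  n_batch         = 128; // hypothetical batch size
        const bool do_checkpoint   = true;

        int n_past = 0;
        while (n_past < n_prompt_tokens) {
            int n_tokens = 0;
            while (n_past < n_prompt_tokens && n_tokens < n_batch) {
                n_past++;
                n_tokens++;
                if (do_checkpoint && n_prompt_tokens - n_past == 64) {
                    break; // leave the last 64 tokens for the next batch
                }
            }
            std::printf("batch of %3d tokens, n_past = %d\n", n_tokens, n_past);
        }
        return 0;
    }

With these numbers the prompt is processed as batches of 128, 108, and 64 tokens, and the checkpoint logic runs after the 236-token mark rather than at the very end of the prompt.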
@@ -3730,6 +3749,39 @@
             slot.i_batch = batch.n_tokens - 1;
 
             SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+
+            const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+            const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+            // no need for empty or small checkpoints
+            do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+            // no need to create checkpoints that are too close together
+            do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+
+            if (do_checkpoint) {
+                while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                    // make room for the new checkpoint, if needed
+                    const auto & cur = slot.ctx_checkpoints.front();
+                    SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                            cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                    slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+                }
+
+                const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+                    /*.pos_min = */ pos_min,
+                    /*.pos_max = */ pos_max,
+                    /*.data = */ std::vector<uint8_t>(checkpoint_size),
+                });
+
+                llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                        (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+            }
         }
     }
 
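The block above keeps at most `params_base.n_ctx_checkpoints` snapshots per slot, evicting the oldest before appending a new one; the snapshot itself is serialized with `llama_state_seq_get_data_ext` using `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY`. A minimal sketch of just the eviction policy, with a hypothetical checkpoint struct:

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Hypothetical, simplified checkpoint record; the real ctx_checkpoint carries
    // the serialized sequence state in `data`.
    struct checkpoint_sketch {
        int32_t              pos_min;
        int32_t              pos_max;
        std::vector<uint8_t> data;
    };

    // Keep at most n_max checkpoints per slot: drop the oldest (front) until there is
    // room, then append the new one (mirrors the while/erase/emplace_back above).
    // Assumes n_max >= 1, as the server only checkpoints when n_ctx_checkpoints > 0.
    static void push_checkpoint(std::vector<checkpoint_sketch> & checkpoints,
                                checkpoint_sketch checkpoint, size_t n_max) {
        while (checkpoints.size() >= n_max) {
            checkpoints.erase(checkpoints.begin());
        }
        checkpoints.push_back(std::move(checkpoint));
    }

Because a new checkpoint is also skipped when its `pos_max` is within 64 positions of the previous one, the list ends up holding a small set of well-spaced restore points.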
@@ -3853,40 +3905,6 @@
 
             // prompt evaluated for next-token prediction
             slot.state = SLOT_STATE_GENERATING;
-
-            // make a checkpoint of the parts of the memory that cannot be rolled back.
-            // checkpoints are created only if:
-            // - the model uses SWA and we are not using `swa_full`
-            // - the model architecture is marked as recurrent or hybrid
-            //
-            // TODO: try to make this conditional on the context or the memory module, instead of the model type
-            const bool do_checkpoint =
-                (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                (llama_model_n_swa(model) > 0 && !params_base.swa_full);
-
-            if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-                while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                    // make room for the new checkpoint, if needed
-                    const auto & cur = slot.ctx_checkpoints.front();
-                    SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                            cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-
-                    slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
-                }
-
-                const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
-                    /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
-                    /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                    /*.data = */ std::vector<uint8_t>(checkpoint_size),
-                });
-
-                llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                        (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-            }
         } else if (slot.state != SLOT_STATE_GENERATING) {
             continue; // continue loop of slots
         }
