
Commit 6fc5bcd

server : cleanup and fixes
1 parent 85d5053

5 files changed: +22, -19 lines

include/llama.h (5 additions, 1 deletion)

@@ -794,7 +794,11 @@ extern "C" {
     size_t n_token_capacity,
     size_t * n_token_count_out);

-#define LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY 1
+// for backwards-compat
+#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+
+// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
+#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1

 typedef uint32_t llama_state_seq_flags;
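
For reference, a minimal sketch of how a caller might save and restore only the partial state of a sequence using the renamed flag. This is an illustrative example, not code from this commit; it assumes an initialized llama_context `ctx` and a valid `seq_id`:

    // illustrative helpers, not part of this commit
    #include <cstdint>
    #include <vector>
    #include "llama.h"

    // serialize only the partial state (e.g. SWA KV cache or recurrent cache) of a sequence
    static std::vector<uint8_t> save_partial_state(llama_context * ctx, llama_seq_id seq_id) {
        const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        std::vector<uint8_t> buf(size);
        llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        return buf;
    }

    // restore a previously saved partial state; returns false if the size check fails
    static bool restore_partial_state(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & buf) {
        const size_t n = llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        return n == buf.size();
    }
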

src/llama-kv-cache-iswa.cpp (2 additions, 2 deletions)

@@ -220,15 +220,15 @@ bool llama_kv_cache_iswa::get_can_shift() const {
 }

 void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_write(io, seq_id, flags);
     }

     kv_swa->state_write(io, seq_id, flags);
 }

 void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_read(io, seq_id, flags);
     }

src/llama-memory-hybrid.cpp (2 additions, 2 deletions)

@@ -175,14 +175,14 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
 }

 void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         mem_attn->state_write(io, seq_id, flags);
     }
     mem_recr->state_write(io, seq_id, flags);
 }

 void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         mem_attn->state_read(io, seq_id, flags);
     }
     mem_recr->state_read(io, seq_id, flags);

src/llama-memory-recurrent.cpp (0 additions, 4 deletions)

@@ -692,8 +692,6 @@ size_t llama_memory_recurrent::size_s_bytes() const {
 }

 void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    // the LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY flag is acknowledged but does not change
-    // behavior here, as there is no notion of a partial state for a recurrent context
     GGML_UNUSED(flags);

     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
@@ -734,8 +732,6 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
 }

 void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    // the LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY flag is acknowledged but does not change
-    // behavior here, as there is no notion of a partial state for a recurrent context
     GGML_UNUSED(flags);

     uint32_t cell_count;

tools/server/server.cpp (13 additions, 10 deletions)

@@ -3573,7 +3573,7 @@ struct server_context {
 if (!do_reset) {
     // restore the context checkpoint
     const size_t ctx_checkpoint_size = it->data.size();
-    const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
+    const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

     if (n != ctx_checkpoint_size) {
         SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
@@ -3598,7 +3598,7 @@ struct server_context {
 for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
     const auto & cur = slot.ctx_checkpoints[i];
     if (cur.pos_min > pos_min_thold) {
-        SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+        SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
         slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
     }
 }
@@ -3854,32 +3854,35 @@ struct server_context {
 // prompt evaluated for next-token prediction
 slot.state = SLOT_STATE_GENERATING;

-// make a checkpoint of the parts of memory that cannot be rolled back.
-// checkpoints are needed only if:
+// make a checkpoint of the parts of the memory that cannot be rolled back.
+// checkpoints are created only if:
 //  - the model uses SWA and we are not using `swa_full`
 //  - the model architecture is marked as recurrent or hybrid
-bool do_checkpoint = (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                     (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+//
+// TODO: try to make this conditional on the context or the memory module, instead of the model type
+const bool do_checkpoint =
+    (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+    (llama_model_n_swa(model) > 0 && !params_base.swa_full);

 if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-    if (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+    while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
         // make room for the new checkpoint, if needed
-        const auto & cur = slot.ctx_checkpoints.back();
+        const auto & cur = slot.ctx_checkpoints.front();
         SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                 cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

         slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
     }

-    const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
+    const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

     auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
         /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
         /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
         /*.data    = */ std::vector<uint8_t>(checkpoint_size),
     });

-    llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
+    llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

     SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
             (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
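
The eviction change above now drops the oldest checkpoints until the cap is respected, instead of reading the newest entry while erasing the front. A standalone sketch of the same FIFO policy, using hypothetical stand-in types rather than the actual server structures:

    // illustrative FIFO eviction for a bounded list of checkpoints (stand-in types, not server.cpp code)
    #include <cstdint>
    #include <utility>
    #include <vector>

    struct checkpoint {
        int pos_min;
        int pos_max;
        std::vector<uint8_t> data;
    };

    static void push_checkpoint(std::vector<checkpoint> & checkpoints, size_t n_max, checkpoint cp) {
        // evict the oldest entries (front of the vector) until there is room for the new one
        while (n_max > 0 && checkpoints.size() >= n_max) {
            checkpoints.erase(checkpoints.begin());
        }
        checkpoints.push_back(std::move(cp));
    }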
