@@ -3541,7 +3541,11 @@ struct server_context {
             slot.n_past = 0;
         }

-        const auto n_swa = llama_model_n_swa(model);
+        // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+        const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+        // the largest pos_min required for a checkpoint to be useful
+        const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

         if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
             const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,17 +3554,16 @@ struct server_context {
                 GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
             }

-            const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
-            if (pos_min > pos_min_thold + 1) {
+            if (pos_min > pos_min_thold) {
                 SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);

                 // search for a context checkpoint
                 const auto it = std::find_if(
                     slot.ctx_checkpoints.rbegin(),
                     slot.ctx_checkpoints.rend(),
                     [&](const auto & cur) {
-                        return cur.pos_min <= pos_min_thold;
+                        // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                        return cur.pos_min < pos_min_thold;
                     }
                 );

@@ -3577,7 +3580,7 @@ struct server_context {
                         do_reset = true;
                         // printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                     } else {
-                        slot.n_past = std::min(slot.n_past, it->pos_max);
+                        slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
                         SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                     }
                 }
@@ -3586,25 +3589,23 @@ struct server_context {
                     SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                             "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                     slot.n_past = 0;
-                    slot.ctx_checkpoints.clear();
                 }
             }
         }

-        if (n_swa > 0) {
-            const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+        {
             // erase any checkpoints with pos_min > pos_min_thold
             for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
                 const auto & cur = slot.ctx_checkpoints[i];
                 if (cur.pos_min > pos_min_thold) {
-                    slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                     SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                    slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                 }
             }
         }
     }

+    // [TAG_PROMPT_LOGITS]
     if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
         SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
         slot.n_past--;
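
For reference, a minimal standalone sketch of the restore logic this diff converges on. The `Checkpoint` struct and the `restore_from_checkpoints` helper are hypothetical names for illustration, not part of the server code; only the threshold math mirrors the diff: `pos_min_thold = max(0, n_past - n_swa)`, the strict `<` comparison, and the `pos_min + 1` clamp.

    // standalone sketch with hypothetical names (Checkpoint, restore_from_checkpoints)
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct Checkpoint {
        int pos_min; // first position covered by the checkpoint
        int pos_max; // last position covered by the checkpoint
    };

    // returns the new n_past after checkpoint selection, or 0 when a full
    // prompt re-processing is required
    int restore_from_checkpoints(const std::vector<Checkpoint> & checkpoints, int n_past, int n_swa_model) {
        // n_swa == 0 means the model does not use SWA, equivalent to a window of 1
        const int n_swa = std::max(1, n_swa_model);

        // the largest pos_min required for a checkpoint to be useful
        const int pos_min_thold = std::max(0, n_past - n_swa);

        // search newest-first; the strict '<' guarantees that restoring leaves
        // at least one token to be processed [TAG_PROMPT_LOGITS]
        const auto it = std::find_if(checkpoints.rbegin(), checkpoints.rend(),
            [&](const Checkpoint & cur) { return cur.pos_min < pos_min_thold; });

        if (it == checkpoints.rend()) {
            return 0; // no usable checkpoint -> full re-processing
        }

        // clamp n_past, but never below pos_min + 1, so the slot always has a token to evaluate
        return std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
    }

    int main() {
        const std::vector<Checkpoint> cps = { {0, 100}, {200, 300} };
        // n_past = 512, n_swa = 256 -> pos_min_thold = 256 -> restores from {200, 300}
        printf("n_past = %d\n", restore_from_checkpoints(cps, 512, 256)); // prints 300
    }

The strict `<` (rather than the old `<=`) and the `pos_min + 1` clamp serve the same invariant tagged [TAG_PROMPT_LOGITS]: after restoring, n_past stays strictly below the prompt length, so the slot always evaluates at least one token and produces logits.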