@@ -764,7 +764,7 @@ struct completion_token_output {
764764 }
765765};
766766
767- struct swa_checkpoint {
767+ struct ctx_checkpoint {
768768 llama_pos pos_min;
769769 llama_pos pos_max;
770770
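For reference, the struct's remaining member lies outside the hunk shown above; inferred from how it is used later in this diff (a byte buffer filled by llama_state_seq_get_data_ext), the complete type is roughly:

// sketch inferred from usage elsewhere in this diff -- not a verbatim copy of the source
struct ctx_checkpoint {
    llama_pos pos_min; // lowest  position covered by the snapshot
    llama_pos pos_max; // highest position covered by the snapshot

    std::vector<uint8_t> data; // serialized per-sequence state for the slot
};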
@@ -1460,7 +1460,7 @@ struct server_slot {
14601460
14611461 std::vector<completion_token_output> generated_token_probs;
14621462
1463- std::vector<swa_checkpoint> swa_checkpoints ;
1463+ std::vector<ctx_checkpoint> ctx_checkpoints ;
14641464
14651465 bool has_next_token = true ;
14661466 bool has_new_line = false ;
@@ -3555,38 +3555,38 @@ struct server_context {
35553555 if (pos_min > pos_min_thold) {
35563556 SLT_WRN (slot, " n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n " , slot.n_past , (int ) slot.cache_tokens .size (), slot.id , pos_min, n_swa);
35573557
3558- // search for a SWA checkpoint
3558+ // search for a context checkpoint
35593559 const auto it = std::find_if (
3560- slot.swa_checkpoints .rbegin (),
3561- slot.swa_checkpoints .rend (),
3560+ slot.ctx_checkpoints .rbegin (),
3561+ slot.ctx_checkpoints .rend (),
35623562 [&](const auto & cur) {
35633563 return cur.pos_min <= pos_min_thold;
35643564 }
35653565 );
35663566
3567- bool do_reset = it == slot.swa_checkpoints .rend ();
3567+ bool do_reset = it == slot.ctx_checkpoints .rend ();
35683569
35693570 if (!do_reset) {
3570- // restore the checkpoint
3571- const size_t swa_size = it->data .size ();
3572- const size_t n = llama_state_seq_set_data_ext (ctx, it->data .data (), swa_size , slot.id , LLAMA_STATE_SEQ_FLAGS_SWA_ONLY );
3571+ // restore the context checkpoint
3572+ const size_t ctx_checkpoint_size = it->data .size ();
3573+ const size_t n = llama_state_seq_set_data_ext (ctx, it->data .data (), ctx_checkpoint_size , slot.id , LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY );
35733574
3574- if (n != swa_size) {
3575+ if (n != ctx_checkpoint_size) {
3575- SLT_ERR (slot, " failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n " , it->pos_min , it->pos_max , (float ) swa_size / 1024 / 1024 );
3576+ SLT_ERR (slot, " failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n " , it->pos_min , it->pos_max , (float ) ctx_checkpoint_size / 1024 / 1024 );
35763577 do_reset = true ;
35773579 } else {
35783580 slot.n_past = std::min (slot.n_past , it->pos_max );
3579-
3580- SLT_WRN (slot, " SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n " , it->pos_min , it->pos_max , (float ) swa_size / 1024 / 1024 );
3581+ SLT_WRN (slot, " restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n " , it->pos_min , it->pos_max , (float ) ctx_checkpoint_size / 1024 / 1024 );
35813582 }
35823583 }
35833584
35843585 if (do_reset) {
35853586 SLT_WRN (slot, " forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n " ,
35863587 " https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055" );
3587-
35883588 slot.n_past = 0 ;
3589- slot.swa_checkpoints .clear ();
3589+ slot.ctx_checkpoints .clear ();
35903590 }
35913591 }
35923592 }
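A short worked example of the restore logic above, with made-up numbers: for a sliding window of n_swa = 1024 and slot.n_past = 8192, the threshold is pos_min_thold = max(0, 8192 - 1024) = 7168. If the sequence's cached pos_min has moved past 7168, positions the SWA window still needs are gone, so the code walks the checkpoints newest-first and takes the first one whose pos_min is still <= 7168, clamping n_past to that checkpoint's pos_max; if none qualifies, the whole prompt is reprocessed.

// illustrative values only -- they mirror the checks above, they are not from the patch
const int n_swa         = 1024;           // assumed sliding-window size
const int n_past        = 8192;           // assumed number of reusable cached tokens
const int pos_min_thold = n_past - n_swa; // = 7168 (clamped to 0 in the real code)

// checkpoint { pos_min = 7000, pos_max = 7500 } -> usable: 7000 <= 7168, restore it and
//                                                  clamp n_past to min(8192, 7500) = 7500
// checkpoint { pos_min = 7200, pos_max = 8100 } -> not usable: 7200 > 7168, it no longer
//                                                  covers the oldest positions needed at 8192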
@@ -3595,21 +3595,20 @@ struct server_context {
35953595 const auto pos_min_thold = std::max (0 , slot.n_past - n_swa);
35963596
35973597 // erase any checkpoints with pos_min > pos_min_thold
3598- for (int i = (int ) slot.swa_checkpoints .size () - 1 ; i >= 0 ; i--) {
3599- const auto & cur = slot.swa_checkpoints [i];
3598+ for (int i = (int ) slot.ctx_checkpoints .size () - 1 ; i >= 0 ; i--) {
3599+ const auto & cur = slot.ctx_checkpoints [i];
36003600 if (cur.pos_min > pos_min_thold) {
3601- slot.swa_checkpoints .erase (slot.swa_checkpoints .begin () + i);
3602-
3603- SLT_WRN (slot, " SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n " , cur.pos_min , cur.pos_max , (float ) cur.data .size () / 1024 / 1024 );
3601+ slot.ctx_checkpoints .erase (slot.ctx_checkpoints .begin () + i);
3602+ SLT_WRN (slot, " erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n " , cur.pos_min , cur.pos_max , n_swa, (float ) cur.data .size () / 1024 / 1024 );
36043603 }
36053604 }
36063605 }
36073606 }
36083607
36093608 if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0 ) {
3610- SLT_WRN (slot, " need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n " , slot.n_past , slot.n_prompt_tokens );
3611-
3609+ SLT_WRN (slot, " need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n " , slot.n_past , slot.n_prompt_tokens );
36123610 slot.n_past --;
3611+ SLT_WRN (slot, " n_past was set to %d\n " , slot.n_past );
36133612 }
36143613
36153614 slot.n_prompt_tokens_cache = slot.n_past ;
@@ -3623,17 +3622,18 @@ struct server_context {
36233622 }
36243623 }
36253624
3626- // keep only the common part
3625+ // truncate any tokens that are beyond n_past for this slot
36273626 if (!llama_memory_seq_rm (llama_get_memory (ctx), slot.id , slot.n_past , -1 )) {
3628- // could not partially delete (likely using a non-Transformer model)
3627+ SLT_WRN (slot, " failed to truncate tokens beyond n_past = %d\n " , slot.n_past );
36293628 llama_memory_seq_rm (llama_get_memory (ctx), slot.id , -1 , -1 );
36303629
36313630 // there is no common part left
36323631 slot.n_past = 0 ;
36333632 slot.n_prompt_tokens_cache = 0 ;
36343634 }
36353635
3636- SLT_INF (slot, " kv cache rm [%d, end) \n " , slot.n_past );
3636+ SLT_INF (slot, " n_past = %d\n " , slot.n_past );
36373637
36383638 // remove the non-common part from the cache
36393639 slot.cache_tokens .keep_first (slot.n_past );
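The truncation above leans on the llama_memory_seq_rm contract; as I understand it from llama.h, the sketch below spells out why the false branch clears the whole sequence:

// hedged sketch of the fallback above, with the (assumed) API semantics as comments
auto * mem = llama_get_memory(ctx);

// llama_memory_seq_rm(mem, seq_id, p0, p1) drops cached tokens of seq_id with positions
// in [p0, p1); a negative bound is open-ended. It returns false when a partial range
// cannot be removed (e.g. recurrent/hybrid state cannot be split at an arbitrary position).
if (!llama_memory_seq_rm(mem, slot.id, slot.n_past, -1)) {
    // removing an entire sequence is always supported, so fall back to a full clear
    // and reuse nothing from the cache
    llama_memory_seq_rm(mem, slot.id, -1, -1);
}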
@@ -3854,37 +3854,35 @@ struct server_context {
38543854 // prompt evaluated for next-token prediction
38553855 slot.state = SLOT_STATE_GENERATING;
38563856
3857- // make a checkpoint with the SWA memory
3858- // checkpoints are needed only if we are not using "--swa-full"
3859- if (llama_model_n_swa (model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0 ) {
3860- if (slot.swa_checkpoints .size () >= (size_t ) params_base.n_swa_checkpoints ) {
3861- {
3862- const auto & cur = slot.swa_checkpoints .back ();
3863-
3864- SLT_WRN (slot, " SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n " ,
3865- cur.pos_min , cur.pos_max , (float ) cur.data .size () / 1024 / 1024 );
3866- }
3867-
3868- slot.swa_checkpoints .erase (slot.swa_checkpoints .begin ());
3857+ // make a checkpoint of the parts of memory that cannot be rolled back.
3858+ // checkpoints are needed only if:
3859+ // - the model uses SWA and we are not using `swa_full`
3860+ // - the model architecture is marked as recurrent or hybrid
3861+ bool do_checkpoint = (llama_model_is_recurrent (model) || llama_model_is_hybrid (model)) ||
3862+ (llama_model_n_swa (model) > 0 && !params_base.swa_full );
3863+
3864+ if (do_checkpoint && params_base.n_ctx_checkpoints > 0 ) {
3865+ if (slot.ctx_checkpoints .size () >= (size_t ) params_base.n_ctx_checkpoints ) {
3866+ // make room for the new checkpoint, if needed
3867+ const auto & cur = slot.ctx_checkpoints .front ();
3868+ SLT_WRN (slot, " erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n " ,
3869+ cur.pos_min , cur.pos_max , (float ) cur.data .size () / 1024 / 1024 );
3870+
3871+ slot.ctx_checkpoints .erase (slot.ctx_checkpoints .begin ());
38693872 }
38703873
3871- const size_t swa_size = llama_state_seq_get_size_ext (ctx, slot.id , LLAMA_STATE_SEQ_FLAGS_SWA_ONLY );
3874+ const size_t checkpoint_size = llama_state_seq_get_size_ext (ctx, slot.id , LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY );
38723875
3873- auto & cur = slot.swa_checkpoints .emplace_back (swa_checkpoint {
3876+ auto & cur = slot.ctx_checkpoints .emplace_back (ctx_checkpoint {
38743877 /* .pos_min = */ llama_memory_seq_pos_min (llama_get_memory (ctx), slot.id ),
38753878 /* .pos_max = */ llama_memory_seq_pos_max (llama_get_memory (ctx), slot.id ),
3876- /* .data = */ std::vector<uint8_t >(swa_size ),
3879+ /* .data = */ std::vector<uint8_t >(checkpoint_size ),
38773880 });
38783881
3879- llama_state_seq_get_data_ext (ctx, cur.data .data (), swa_size, slot.id , LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
3880-
3881- float size_total = 0 .0f ;
3882- for (const auto & checkpoint : slot.swa_checkpoints ) {
3883- size_total += (float ) checkpoint.data .size () / 1024 / 1024 ;
3884- }
3882+ llama_state_seq_get_data_ext (ctx, cur.data .data (), checkpoint_size, slot.id , LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
38853883
3886- SLT_WRN (slot, " SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d ( %.3f MiB)\n " ,
3887- cur.pos_min , cur.pos_max , (float ) cur.data .size () / 1024 / 1024 , (int ) slot. swa_checkpoints . size (), params_base. n_swa_checkpoints , size_total );
3884+ SLT_WRN (slot, " saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n " ,
3885+ (int ) slot.ctx_checkpoints .size (), params_base.n_ctx_checkpoints , cur.pos_min , cur.pos_max , (float ) cur.data .size () / 1024 / 1024 );
38883886 }
38893887 } else if (slot.state != SLOT_STATE_GENERATING) {
38903888 continue ; // continue loop of slots
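The checkpoint condition introduced above can also be read as a standalone predicate. The helper below is a hypothetical restatement (not part of the patch) of when the snapshots are worth taking: whenever some per-sequence state cannot be recovered by simply rolling the KV cache back.

// hypothetical helper, equivalent to the inline `do_checkpoint` expression above
static bool slot_needs_checkpoints(const llama_model * model, const common_params & params) {
    // recurrent / hybrid models carry state that cannot be truncated to an earlier position
    if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
        return true;
    }

    // SWA models prune positions that fall out of the window unless --swa-full keeps them all
    return llama_model_n_swa(model) > 0 && !params.swa_full;
}

Each slot keeps at most params_base.n_ctx_checkpoints of these snapshots, so the added memory per slot is bounded by that count times the per-checkpoint size reported in the log above.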