@@ -764,7 +764,7 @@ struct completion_token_output {
     }
 };
 
-struct swa_checkpoint {
+struct ctx_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
 
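For orientation, here is a minimal standalone sketch of the renamed struct as the later hunks use it: `pos_min`/`pos_max` bound the positions covered by the checkpoint and `data` holds the serialized sequence state. The `llama_pos` alias is a stand-in for the typedef in `llama.h`; this is illustrative, not the verbatim definition from `server.cpp`:

    #include <cstdint>
    #include <vector>

    using llama_pos = int32_t; // stand-in for the typedef in llama.h

    // a saved copy of the part of a sequence's memory that cannot be rolled back,
    // covering positions [pos_min, pos_max]
    struct ctx_checkpoint {
        llama_pos pos_min;
        llama_pos pos_max;

        std::vector<uint8_t> data; // opaque blob produced by llama_state_seq_get_data_ext()
    };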
@@ -1460,7 +1460,7 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
-    std::vector<swa_checkpoint> swa_checkpoints;
+    std::vector<ctx_checkpoint> ctx_checkpoints;
 
     bool has_next_token = true;
     bool has_new_line   = false;
@@ -3541,7 +3541,11 @@ struct server_context {
                             slot.n_past = 0;
                         }
 
-                        const auto n_swa = llama_model_n_swa(model);
+                        // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+                        const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+                        // the largest pos_min required for a checkpoint to be useful
+                        const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
                         if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                             const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,66 +3554,62 @@ struct server_context {
                                 GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                             }
 
-                            const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
                             if (pos_min > pos_min_thold) {
                                 SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
 
-                                // search for a SWA checkpoint
+                                // search for a context checkpoint
                                 const auto it = std::find_if(
-                                    slot.swa_checkpoints.rbegin(),
-                                    slot.swa_checkpoints.rend(),
+                                    slot.ctx_checkpoints.rbegin(),
+                                    slot.ctx_checkpoints.rend(),
                                     [&](const auto & cur) {
-                                        return cur.pos_min <= pos_min_thold;
+                                        // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                        return cur.pos_min < pos_min_thold;
                                     }
                                 );
 
-                                bool do_reset = it == slot.swa_checkpoints.rend();
+                                bool do_reset = it == slot.ctx_checkpoints.rend();
+                                //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
 
                                 if (!do_reset) {
-                                    // restore the checkpoint
-                                    const size_t swa_size = it->data.size();
-                                    const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                                    // restore the context checkpoint
+                                    const size_t ctx_checkpoint_size = it->data.size();
+                                    const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                                    if (n != swa_size) {
-                                        SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                    if (n != ctx_checkpoint_size) {
+                                        SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                         do_reset = true;
+                                        //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                                     } else {
-                                        slot.n_past = std::min(slot.n_past, it->pos_max);
-
-                                        SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                        slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+                                        SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                     }
                                 }
 
                                 if (do_reset) {
-                                    SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                                    SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                             "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
                                     slot.n_past = 0;
-                                    slot.swa_checkpoints.clear();
                                 }
                             }
                         }
 
-                        if (n_swa > 0) {
-                            const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+                        {
                             // erase any checkpoints with pos_min > pos_min_thold
-                            for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
-                                const auto & cur = slot.swa_checkpoints[i];
+                            for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                                const auto & cur = slot.ctx_checkpoints[i];
                                 if (cur.pos_min > pos_min_thold) {
-                                    slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
-                                    SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                                    SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                                    slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                                 }
                             }
                         }
                     }
 
+                    // [TAG_PROMPT_LOGITS]
                     if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                        SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+                        SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
                         slot.n_past--;
+                        SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
                     }
 
                     slot.n_prompt_tokens_cache = slot.n_past;
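The restore path in the hunk above searches the per-slot checkpoint list back-to-front for the newest entry whose `pos_min` is strictly below `pos_min_thold = max(0, n_past - n_swa)`; the strict comparison is what guarantees at least one prompt token is still evaluated after the restore. Below is a compact, self-contained sketch of just that selection rule, using simplified stand-in types rather than the server's actual slot structures.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using llama_pos = int32_t;

    struct ctx_checkpoint {
        llama_pos pos_min;
        llama_pos pos_max;
        std::vector<uint8_t> data;
    };

    // returns the index of the newest usable checkpoint, or -1 if a full
    // prompt re-processing (n_past = 0) is required
    static int pick_checkpoint(const std::vector<ctx_checkpoint> & checkpoints, int n_past, int n_swa) {
        // n_swa == 0 means the model has no SWA, which behaves like a window of 1
        const int n_swa_eff = std::max(1, n_swa);

        // the largest pos_min a checkpoint may have and still be useful
        const int pos_min_thold = std::max(0, n_past - n_swa_eff);

        for (int i = (int) checkpoints.size() - 1; i >= 0; i--) {
            // strict '<' leaves at least one token to process after restoring
            if (checkpoints[i].pos_min < pos_min_thold) {
                return i;
            }
        }

        return -1;
    }

    int main() {
        const std::vector<ctx_checkpoint> checkpoints = {
            {   0, 512, {} },
            { 600, 900, {} },
        };

        // e.g. n_past = 1000 with a SWA window of 256 gives pos_min_thold = 744,
        // so the newest checkpoint, covering [600, 900], is selected (index 1)
        printf("picked checkpoint index: %d\n", pick_checkpoint(checkpoints, 1000, 256));

        return 0;
    }

After a successful restore, the server additionally clamps `slot.n_past` via `std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max))`, as shown in the hunk, before continuing with the remaining prompt tokens.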
@@ -3623,17 +3623,17 @@ struct server_context {
                         }
                     }
 
-                    // keep only the common part
+                    // truncate any tokens that are beyond n_past for this slot
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                        // could not partially delete (likely using a non-Transformer model)
+                        SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
                         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                         // there is no common part left
                         slot.n_past = 0;
                         slot.n_prompt_tokens_cache = 0;
                     }
 
-                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+                    SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
 
                     // remove the non-common part from the cache
                     slot.cache_tokens.keep_first(slot.n_past);
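The fallback in this hunk reflects how `llama_memory_seq_rm` behaves: removing only the tail `[n_past, end)` of a sequence can fail for memory types that cannot drop part of a sequence (e.g. recurrent state), in which case the whole sequence is cleared and the prompt is re-processed from position 0. A hedged sketch of that pattern as a free-standing helper follows; the helper name and return convention are mine, not the server's.

    #include "llama.h"

    // keep only the first n_keep positions of a sequence; returns how many
    // positions were actually kept (0 if the whole sequence had to be cleared
    // because partial removal is not supported by this memory type)
    static int keep_prefix(llama_context * ctx, llama_seq_id seq_id, llama_pos n_keep) {
        auto mem = llama_get_memory(ctx);

        // remove positions [n_keep, end)
        if (!llama_memory_seq_rm(mem, seq_id, n_keep, -1)) {
            // fall back to clearing the entire sequence
            llama_memory_seq_rm(mem, seq_id, -1, -1);
            return 0;
        }

        return n_keep;
    }

In the server, the 0 case corresponds to resetting both `slot.n_past` and `slot.n_prompt_tokens_cache` before re-processing the prompt.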
@@ -3854,37 +3854,38 @@ struct server_context {
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
 
-                    // make a checkpoint with the SWA memory
-                    // checkpoints are needed only if we are not using "--swa-full"
-                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
-                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
-                            {
-                                const auto & cur = slot.swa_checkpoints.back();
-
-                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
-                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                            }
-
-                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    // - the model uses SWA and we are not using `swa_full`
+                    // - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    const bool do_checkpoint =
+                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            // make room for the new checkpoint, if needed
+                            const auto & cur = slot.ctx_checkpoints.front();
+                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
                         }
 
-                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                             /* .pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                             /* .pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /* .data    = */ std::vector<uint8_t>(swa_size),
+                            /* .data    = */ std::vector<uint8_t>(checkpoint_size),
                         });
 
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-                        float size_total = 0.0f;
-                        for (const auto & checkpoint : slot.swa_checkpoints) {
-                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
-                        }
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
-                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                     }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
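Finally, the creation path in the last hunk keeps at most `params_base.n_ctx_checkpoints` entries per slot, evicting from the front (oldest first) before appending the new checkpoint. Below is a self-contained sketch of that bookkeeping, using the same llama.cpp state/memory calls as the diff; the `save_checkpoint` helper and its parameters are illustrative stand-ins for the slot and server fields.

    #include "llama.h"

    #include <cstdint>
    #include <vector>

    struct ctx_checkpoint {
        llama_pos pos_min;
        llama_pos pos_max;
        std::vector<uint8_t> data;
    };

    // append a checkpoint for seq_id, evicting the oldest entries so that at
    // most n_ctx_checkpoints are kept (FIFO, mirroring the loop in the diff)
    static void save_checkpoint(llama_context * ctx, llama_seq_id seq_id,
                                std::vector<ctx_checkpoint> & checkpoints, int n_ctx_checkpoints) {
        if (n_ctx_checkpoints <= 0) {
            return;
        }

        // make room for the new checkpoint, if needed
        while (checkpoints.size() >= (size_t) n_ctx_checkpoints) {
            checkpoints.erase(checkpoints.begin());
        }

        // LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY (as used above) serializes only the
        // part of the sequence state that cannot be rolled back
        const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        auto & cur = checkpoints.emplace_back(ctx_checkpoint{
            /* .pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), seq_id),
            /* .pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id),
            /* .data    = */ std::vector<uint8_t>(size),
        });

        llama_state_seq_get_data_ext(ctx, cur.data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    }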