@@ -3573,7 +3573,7 @@ struct server_context {
     if (!do_reset) {
         // restore the context checkpoint
         const size_t ctx_checkpoint_size = it->data.size();
-        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
+        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

         if (n != ctx_checkpoint_size) {
             SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
@@ -3598,7 +3598,7 @@ struct server_context {
     for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
         const auto & cur = slot.ctx_checkpoints[i];
         if (cur.pos_min > pos_min_thold) {
-            SLT_WRN(slot, "erased invalidated context checkpoint for SWA (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+            SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
             slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
         }
     }
@@ -3854,32 +3854,35 @@ struct server_context {
     // prompt evaluated for next-token prediction
     slot.state = SLOT_STATE_GENERATING;

-    // make a checkpoint of the parts of memory that cannot be rolled back.
-    // checkpoints are needed only if:
+    // make a checkpoint of the parts of the memory that cannot be rolled back.
+    // checkpoints are created only if:
     //  - the model uses SWA and we are not using `swa_full`
     //  - the model architecture is marked as recurrent or hybrid
-    bool do_checkpoint = (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                         (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+    //
+    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+    const bool do_checkpoint =
+        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+        (llama_model_n_swa(model) > 0 && !params_base.swa_full);

     if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-        if (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
             // make room for the new checkpoint, if needed
-            const auto & cur = slot.ctx_checkpoints.back();
+            const auto & cur = slot.ctx_checkpoints.front();
             SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                     cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);

             slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
         }

-        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
+        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

         auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
             /* .pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
             /* .pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
             /* .data    = */ std::vector<uint8_t>(checkpoint_size),
         });

-        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_CHECKPOINT_ONLY);
+        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

         SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
                 (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
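For reference, the checkpoint bookkeeping these hunks touch boils down to two operations on the per-slot checkpoint list: FIFO eviction of the oldest entry (front of the vector) while the list is full, and removal of entries whose pos_min exceeds a threshold. A minimal standalone sketch of that pattern follows; the checkpoint_info struct and the save_checkpoint / invalidate_checkpoints helpers are illustrative names only, not the server's actual types or API.

// Minimal, self-contained sketch (hypothetical names; not the server's code).
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct checkpoint_info {
    int pos_min;
    int pos_max;
    std::vector<uint8_t> data;
};

// FIFO eviction: drop the oldest entries (front of the vector) until there is
// room for one more, then append the new checkpoint at the back.
static void save_checkpoint(std::vector<checkpoint_info> & checkpoints, size_t max_checkpoints, checkpoint_info cp) {
    while (checkpoints.size() >= max_checkpoints) {
        checkpoints.erase(checkpoints.begin());
    }
    checkpoints.push_back(std::move(cp));
}

// Invalidation: walking back to front, erase every checkpoint whose pos_min is
// beyond the threshold, i.e. checkpoints that can no longer be restored safely.
static void invalidate_checkpoints(std::vector<checkpoint_info> & checkpoints, int pos_min_thold) {
    for (int i = (int) checkpoints.size() - 1; i >= 0; i--) {
        if (checkpoints[i].pos_min > pos_min_thold) {
            checkpoints.erase(checkpoints.begin() + i);
        }
    }
}

int main() {
    std::vector<checkpoint_info> checkpoints;

    // keep at most 2 checkpoints while saving 4 -> only the last 2 survive
    for (int p = 0; p < 4; p++) {
        save_checkpoint(checkpoints, 2, checkpoint_info{p * 100, p * 100 + 99, {}});
    }

    // drop checkpoints whose pos_min lies past the threshold
    invalidate_checkpoints(checkpoints, 250);

    for (const auto & cp : checkpoints) {
        printf("kept checkpoint: pos_min = %d, pos_max = %d\n", cp.pos_min, cp.pos_max);
    }
    return 0;
}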