@@ -3676,6 +3676,20 @@ struct server_context {
                         alora_disabled_id = enabled_loras[0];
                     }

+                    bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    // - the model uses SWA and we are not using `swa_full`
+                    // - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    do_checkpoint = do_checkpoint && (
+                        llama_model_is_recurrent(model) ||
+                        llama_model_is_hybrid(model) ||
+                        (llama_model_n_swa(model) > 0 && !params_base.swa_full)
+                    );
+
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                         // get next token to process
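
The condition introduced in this hunk gates checkpointing on both the configured budget (`n_ctx_checkpoints > 0`) and the model type. A minimal standalone restatement of that predicate, with plain parameters standing in for the `llama_model_*` queries and `params_base` fields (illustrative only, not part of the patch):

    // checkpoints are only worth taking when the memory cannot be rolled back
    // freely: recurrent or hybrid models, or SWA models without swa_full
    static bool want_checkpoints(int n_ctx_checkpoints, bool is_recurrent, bool is_hybrid,
                                 int n_swa, bool swa_full) {
        if (n_ctx_checkpoints <= 0) {
            return false;
        }
        return is_recurrent || is_hybrid || (n_swa > 0 && !swa_full);
    }
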
@@ -3700,6 +3714,11 @@ struct server_context {

                         slot.n_prompt_tokens_processed++;
                         slot.n_past++;
+
+                        // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
+                        if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+                            break;
+                        }
                     }

                     // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
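
The early `break` above splits the prompt so that one batch boundary falls exactly 64 tokens before the end: when the final chunk is about to be decoded, the memory already contains everything up to that point and can be checkpointed (see the next hunk). A self-contained sketch of just this splitting arithmetic, using plain integers in place of the server's slot and batch state (names are illustrative):

    #include <cstdio>

    int main() {
        const int n_prompt_tokens = 300; // hypothetical prompt length
        const int n_batch         = 128; // hypothetical batch size
        const int tail            = 64;  // tokens deliberately left for the final batch

        int n_past = 0;
        while (n_past < n_prompt_tokens) {
            int n_tokens = 0;
            while (n_past < n_prompt_tokens && n_tokens < n_batch) {
                n_tokens++;
                n_past++;
                // mirror of the early break in the hunk above
                if (n_prompt_tokens - n_past == tail) {
                    break;
                }
            }
            // in the server, the checkpoint is taken once only the last
            // `tail` tokens remain to be decoded
            std::printf("batch of %3d tokens, n_past = %3d\n", n_tokens, n_past);
        }
        return 0;
    }
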
@@ -3730,6 +3749,39 @@ struct server_context {
                         slot.i_batch = batch.n_tokens - 1;

                         SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+
+                        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                        // no need for empty or small checkpoints
+                        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+                        // no need to create checkpoints that are too close together
+                        do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+
+                        if (do_checkpoint) {
+                            while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                                // make room for the new checkpoint, if needed
+                                const auto & cur = slot.ctx_checkpoints.front();
+                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+                            }
+
+                            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+                                /* .pos_min = */ pos_min,
+                                /* .pos_max = */ pos_max,
+                                /* .data    = */ std::vector<uint8_t>(checkpoint_size),
+                            });
+
+                            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        }
                     }
                 }

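
The block above enforces two policies before serializing the sequence state: evict the oldest checkpoint once the per-slot budget is reached, and skip checkpoints whose `pos_max` is within 64 positions of the previous one. A self-contained sketch of that eviction/spacing policy, with a hypothetical `checkpoint` struct standing in for `ctx_checkpoint` and a plain byte vector standing in for the serialized state:

    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>
    #include <utility>
    #include <vector>

    // hypothetical stand-in for the server's ctx_checkpoint
    struct checkpoint {
        int pos_min;
        int pos_max;
        std::vector<uint8_t> data;
    };

    // append a checkpoint following the same rules as the hunk above:
    // reject empty/small or closely spaced checkpoints, and keep at most
    // n_max entries per slot, evicting the oldest first
    static bool push_checkpoint(std::vector<checkpoint> & cps,
                                int pos_min, int pos_max,
                                std::vector<uint8_t> data, size_t n_max) {
        if (pos_min < 0 || pos_max < 64) {
            return false; // no need for empty or small checkpoints
        }
        if (!cps.empty() && pos_max <= cps.back().pos_max + 64) {
            return false; // too close to the previous checkpoint
        }
        while (cps.size() >= n_max) {
            cps.erase(cps.begin()); // make room for the new checkpoint
        }
        cps.push_back({pos_min, pos_max, std::move(data)});
        return true;
    }

    int main() {
        std::vector<checkpoint> cps;
        for (int pos : {100, 120, 300, 700, 1500}) {
            const bool ok = push_checkpoint(cps, 0, pos, std::vector<uint8_t>(32), /*n_max =*/ 3);
            std::printf("pos_max = %4d -> %s (stored: %zu)\n", pos, ok ? "saved" : "skipped", cps.size());
        }
        return 0;
    }
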
@@ -3853,40 +3905,6 @@ struct server_context {

                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
-
-                    // make a checkpoint of the parts of the memory that cannot be rolled back.
-                    // checkpoints are created only if:
-                    // - the model uses SWA and we are not using `swa_full`
-                    // - the model architecture is marked as recurrent or hybrid
-                    //
-                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
-                    const bool do_checkpoint =
-                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
-
-                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                            // make room for the new checkpoint, if needed
-                            const auto & cur = slot.ctx_checkpoints.front();
-                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-
-                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
-                        }
-
-                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
-                            /* .pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
-                            /* .pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /* .data    = */ std::vector<uint8_t>(checkpoint_size),
-                        });
-
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                    }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }
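
These checkpoints are only useful if they can be loaded back when a slot's context is rolled back past a point the SWA/recurrent memory cannot reconstruct. A hedged sketch of what a restore step could look like, assuming a `llama_state_seq_set_data_ext()` counterpart to the `_get_` calls used in the patch; the selection loop and the helper name are illustrative, not the server's actual rollback path:

    // would live inside server.cpp, where ctx_checkpoint and llama.h are available:
    // pick the newest checkpoint that does not extend past the rollback target,
    // load it into the sequence, and report the position decoding should resume from
    static int restore_checkpoint(llama_context * ctx, llama_seq_id seq_id,
                                  const std::vector<ctx_checkpoint> & cps, int pos_target) {
        for (auto it = cps.rbegin(); it != cps.rend(); ++it) {
            if (it->pos_max <= pos_target) {
                // assumption: set_data_ext mirrors the get_data_ext call used above
                llama_state_seq_set_data_ext(ctx, it->data.data(), it->data.size(), seq_id,
                                             LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
                return it->pos_max; // caller re-decodes tokens in (pos_max, pos_target]
            }
        }
        return -1; // no usable checkpoint: the caller must reprocess from scratch
    }
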