@@ -692,6 +692,13 @@ struct completion_token_output {
     }
 };
 
+struct swa_checkpoint {
+    std::vector<uint8_t> data;
+
+    llama_pos pos_min;
+    llama_pos pos_max;
+};
+
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;
 
@@ -1336,6 +1343,8 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
+    std::vector<swa_checkpoint> swa_checkpoints;
+
     bool has_next_token = true;
     bool has_new_line = false;
     bool truncated = false;
@@ -3300,10 +3309,42 @@ struct server_context {
 
                     const auto n_swa = llama_model_n_swa(model);
                     if (pos_min > std::max(0, slot.n_past - n_swa)) {
-                        SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
-                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
-                                "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                        slot.n_past = 0;
+                        // search for a SWA checkpoint
+                        int ic = -1;
+                        int np = std::numeric_limits<int>::max();
+                        for (int i = 0; i < (int) slot.swa_checkpoints.size(); i++) {
+                            const auto & cur = slot.swa_checkpoints[i];
+                            if (cur.pos_min <= std::max(0, slot.n_past - n_swa)) {
+                                const int p = std::max(0, slot.n_past - cur.pos_max);
+
+                                if (p < np) {
+                                    ic = i;
+                                    np = p;
+                                }
+                            }
+                        }
+
+                        if (ic == -1) {
+                            SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
+                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+                            slot.n_past = 0;
+
+                            slot.swa_checkpoints.clear();
+                        } else {
+                            // erase all checkpoints after the one we are using
+                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + ic + 1, slot.swa_checkpoints.end());
+
+                            // restore the checkpoint
+                            const auto & cur = slot.swa_checkpoints[ic];
+
+                            const size_t swa_size = cur.data.size();
+                            llama_state_seq_set_data(ctx, cur.data.data(), swa_size, slot.id);
+
+                            slot.n_past = std::min(slot.n_past, cur.pos_max);
+
+                            SLT_WRN(slot, "prompt swa checkpoint restored, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
+                        }
                     }
                 }
             }
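
For reference, the selection loop in the hunk above considers only checkpoints whose `pos_min` still covers the oldest position the SWA window needs (`slot.n_past - n_swa`) and, among those, picks the one with the largest usable `pos_max`, so that `slot.n_past` has to be rolled back as little as possible. A minimal standalone sketch of that rule (the `find_best_checkpoint` helper and the sample positions are illustrative only, not part of the patch):

```cpp
// illustrative sketch of the checkpoint-selection rule used in the hunk above
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

struct swa_checkpoint_info {
    int pos_min; // oldest position still present in the checkpointed state
    int pos_max; // newest position present in the checkpointed state
};

// returns the index of the most useful checkpoint, or -1 if none is usable
static int find_best_checkpoint(const std::vector<swa_checkpoint_info> & cps, int n_past, int n_swa) {
    int ic = -1;
    int np = std::numeric_limits<int>::max();

    for (int i = 0; i < (int) cps.size(); i++) {
        // usable only if it still covers the start of the required SWA window
        if (cps[i].pos_min <= std::max(0, n_past - n_swa)) {
            // how far n_past would have to be rolled back after restoring it
            const int p = std::max(0, n_past - cps[i].pos_max);
            if (p < np) {
                ic = i;
                np = p;
            }
        }
    }

    return ic;
}

int main() {
    const std::vector<swa_checkpoint_info> cps = { {0, 512}, {300, 1024}, {900, 2048} };

    // n_past = 1500, n_swa = 1024 -> the window must reach back to position 476,
    // so the checkpoint {300, 1024} is chosen and n_past is rolled back to 1024
    printf("best checkpoint index: %d\n", find_best_checkpoint(cps, 1500, 1024));

    return 0;
}
```
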
@@ -3517,6 +3558,25 @@ struct server_context {
 
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
+
+                    // make a checkpoint
+                    if (llama_model_n_swa(model) > 0) {
+                        if (slot.swa_checkpoints.size() > 8) {
+                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                        }
+
+                        auto & cur = slot.swa_checkpoints.emplace_back();
+
+                        cur.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                        cur.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                        const size_t swa_size = llama_state_seq_get_size(ctx, slot.id);
+                        cur.data.resize(swa_size);
+
+                        llama_state_seq_get_data(ctx, cur.data.data(), swa_size, slot.id);
+
+                        SLT_WRN(slot, "prompt swa checkpoint, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
+                    }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }
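
The checkpoint is captured with the `llama_state_seq_*` API: `llama_state_seq_get_size` reports how large the serialized state of the sequence is, `llama_state_seq_get_data` copies it into the buffer, and `llama_state_seq_set_data` (used in the restore path earlier) loads it back. A minimal sketch of that round trip, assuming an already-initialized `llama_context * ctx` (model and context setup omitted; the helper names are illustrative, not part of the patch):

```cpp
// sketch of the save/restore round trip behind the SWA checkpoints,
// assuming `ctx` points to an initialized llama_context
#include <cstdint>
#include <vector>

#include "llama.h"

// serialize the state of one sequence into a byte buffer
static std::vector<uint8_t> save_seq_state(llama_context * ctx, llama_seq_id seq_id) {
    std::vector<uint8_t> data(llama_state_seq_get_size(ctx, seq_id));
    llama_state_seq_get_data(ctx, data.data(), data.size(), seq_id);
    return data;
}

// load a previously saved buffer back into the given sequence
static void restore_seq_state(llama_context * ctx, const std::vector<uint8_t> & data, llama_seq_id seq_id) {
    llama_state_seq_set_data(ctx, data.data(), data.size(), seq_id);
}
```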