31 | 31 | #include <unordered_map> |
32 | 32 | #include <unordered_set> |
33 | 33 |
| 34 | +#define SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT 3 |
| 35 | + |
34 | 36 | using json = nlohmann::ordered_json; |
35 | 37 |
36 | 38 | constexpr int HTTP_POLLING_SECONDS = 1; |
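The new `SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT` constant caps how many SWA checkpoints a slot may keep, replacing the hard-coded `8` in the old creation path further down. Each checkpoint holds a full serialized copy of the slot's sequence state, so the cap directly bounds the extra memory a slot can spend on checkpointing. A rough sketch of that bound, reusing the `llama_state_seq_get_size()` call the checkpoint code below already relies on; the helper name is hypothetical and not part of the patch:

```cpp
#include "llama.h"

#define SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT 3 // as introduced above

// Hypothetical helper: worst-case extra bytes one slot can spend on SWA
// checkpoints, i.e. the cap times the size of one serialized sequence state.
static size_t swa_checkpoint_budget(llama_context * ctx, llama_seq_id seq_id) {
    return SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT * llama_state_seq_get_size(ctx, seq_id);
}
```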
@@ -693,10 +695,10 @@ struct completion_token_output { |
693 | 695 | }; |
694 | 696 |
695 | 697 | struct swa_checkpoint { |
696 | | - std::vector<uint8_t> data; |
697 | | - |
698 | 698 | llama_pos pos_min; |
699 | 699 | llama_pos pos_max; |
| 700 | + |
| 701 | + std::vector<uint8_t> data; |
700 | 702 | }; |
701 | 703 |
702 | 704 | struct server_task_result_cmpl_final : server_task_result { |
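Reordering `swa_checkpoint` so that the position metadata precedes the state buffer lets the creation site in the last hunk fill the struct with aggregate initialization in declaration order. A minimal, self-contained sketch of the new layout, assuming only that `llama_pos` is the `int32_t` typedef from `llama.h`:

```cpp
#include <cstdint>
#include <vector>

using llama_pos = int32_t; // stand-in for the llama.h typedef

struct swa_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;

    std::vector<uint8_t> data;
};

int main() {
    // aggregate initialization must follow member order, so the metadata-first
    // layout matches how the creation site below fills in the checkpoint
    swa_checkpoint cp = {
        /*.pos_min =*/ 0,
        /*.pos_max =*/ 127,
        /*.data    =*/ std::vector<uint8_t>(1024),
    };

    return cp.data.size() == 1024 ? 0 : 1;
}
```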
@@ -3300,50 +3302,56 @@ struct server_context { |
3300 | 3302 | slot.n_past = 0; |
3301 | 3303 | } |
3302 | 3304 |
| 3305 | + const auto n_swa = llama_model_n_swa(model); |
| 3306 | + |
3303 | 3307 | if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { |
3304 | 3308 | const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); |
3305 | 3309 | if (pos_min == -1) { |
3306 | 3310 | SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); |
3307 | 3311 | GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); |
3308 | 3312 | } |
3309 | 3313 |
3310 | | - const auto n_swa = llama_model_n_swa(model); |
3311 | | - if (pos_min > std::max(0, slot.n_past - n_swa)) { |
| 3314 | + const auto pos_min_thold = std::max(0, slot.n_past - n_swa); |
| 3315 | + |
| 3316 | + if (pos_min > pos_min_thold) { |
3312 | 3317 | // search for a SWA checkpoint |
3313 | | - int ic = -1; |
3314 | | - int np = std::numeric_limits<int>::max(); |
3315 | | - for (int i = 0; i < (int) slot.swa_checkpoints.size(); i++) { |
3316 | | - const auto & cur = slot.swa_checkpoints[i]; |
3317 | | - if (cur.pos_min <= std::max(0, slot.n_past - n_swa)) { |
3318 | | - const int p = std::max(0, slot.n_past - cur.pos_max); |
3319 | | - |
3320 | | - if (p < np) { |
3321 | | - ic = i; |
3322 | | - np = p; |
3323 | | - } |
| 3318 | + auto it = std::find_if( |
| 3319 | + slot.swa_checkpoints.rbegin(), |
| 3320 | + slot.swa_checkpoints.rend(), |
| 3321 | + [&](const auto & cur) { |
| 3322 | + return cur.pos_min <= pos_min_thold; |
3324 | 3323 | } |
3325 | | - } |
| 3324 | + ); |
3326 | 3325 |
3327 | | - if (ic == -1) { |
| 3326 | + if (it == slot.swa_checkpoints.rend()) { |
3328 | 3327 | SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); |
3329 | 3328 | SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", |
3330 | 3329 | "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); |
3331 | | - slot.n_past = 0; |
3332 | 3330 |
| 3331 | + slot.n_past = 0; |
3333 | 3332 | slot.swa_checkpoints.clear(); |
3334 | 3333 | } else { |
3335 | | - // erase all checkpoints after the one we are using |
3336 | | - slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + ic + 1, slot.swa_checkpoints.end()); |
3337 | | - |
3338 | 3334 | // restore the checkpoint |
3339 | | - const auto & cur = slot.swa_checkpoints[ic]; |
| 3335 | + const size_t swa_size = it->data.size(); |
| 3336 | + llama_state_seq_set_data(ctx, it->data.data(), swa_size, slot.id); |
3340 | 3337 |
3341 | | - const size_t swa_size = cur.data.size(); |
3342 | | - llama_state_seq_set_data(ctx, cur.data.data(), swa_size, slot.id); |
| 3338 | + slot.n_past = std::min(slot.n_past, it->pos_max); |
3343 | 3339 |
3344 | | - slot.n_past = std::min(slot.n_past, cur.pos_max); |
| 3340 | + SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024); |
| 3341 | + } |
| 3342 | + } |
| 3343 | + } |
3345 | 3344 |
3346 | | - SLT_WRN(slot, "prompt swa checkpoint restored, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); |
| 3345 | + if (n_swa > 0) { |
| 3346 | + const auto pos_min_thold = std::max(0, slot.n_past - n_swa); |
| 3347 | + |
| 3348 | + // erase any checkpoints with pos_min > pos_min_thold |
| 3349 | + for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) { |
| 3350 | + const auto & cur = slot.swa_checkpoints[i]; |
| 3351 | + if (cur.pos_min > pos_min_thold) { |
| 3352 | + slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i); |
| 3353 | + |
| 3354 | + SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); |
3347 | 3355 | } |
3348 | 3356 | } |
3349 | 3357 | } |
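The index-based search over `slot.swa_checkpoints` is replaced by a reverse `std::find_if`: checkpoints are appended in creation order, so scanning from `rbegin()` yields the newest checkpoint whose `pos_min` is at or below `pos_min_thold = max(0, n_past - n_swa)`, and the separate `if (n_swa > 0)` block afterwards prunes any checkpoints whose `pos_min` lies beyond the (possibly lowered) threshold. A self-contained sketch of the selection step, with the llama-specific state stripped out and illustrative names such as `pick_checkpoint`:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct checkpoint {
    int32_t pos_min;
    int32_t pos_max;
};

// Return the index of the newest checkpoint that is still usable for a SWA
// cache of window n_swa when n_past tokens are already confirmed, or -1.
static int pick_checkpoint(const std::vector<checkpoint> & cps, int n_past, int n_swa) {
    const int pos_min_thold = std::max(0, n_past - n_swa);

    auto it = std::find_if(cps.rbegin(), cps.rend(),
        [&](const checkpoint & cur) { return cur.pos_min <= pos_min_thold; });

    if (it == cps.rend()) {
        return -1; // no usable checkpoint -> full prompt re-processing
    }

    // map the reverse iterator back to a forward index
    return (int) std::distance(cps.begin(), it.base()) - 1;
}

int main() {
    // checkpoints are stored oldest-first, as in slot.swa_checkpoints
    const std::vector<checkpoint> cps = { {0, 100}, {150, 300}, {400, 600} };

    // n_past = 500, n_swa = 256 -> pos_min_thold = 244, so the newest usable
    // checkpoint is the middle one (index 1); the last one starts too late
    return pick_checkpoint(cps, 500, 256) == 1 ? 0 : 1;
}
```

Since checkpoints are created at increasing positions, taking the newest qualifying one keeps the `std::min(slot.n_past, it->pos_max)` rollback as small as possible, which is what the removed index loop computed explicitly.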
@@ -3559,23 +3567,29 @@ struct server_context { |
3559 | 3567 | // prompt evaluated for next-token prediction |
3560 | 3568 | slot.state = SLOT_STATE_GENERATING; |
3561 | 3569 |
3562 | | - // make a checkpoint |
| 3570 | + // make a checkpoint with the SWA memory |
3563 | 3571 | if (llama_model_n_swa(model) > 0) { |
3564 | | - if (slot.swa_checkpoints.size() > 8) { |
3565 | | - slot.swa_checkpoints.erase(slot.swa_checkpoints.begin()); |
3566 | | - } |
| 3572 | + if (slot.swa_checkpoints.size() >= SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT) { |
| 3573 | + { |
| 3574 | + const auto & cur = slot.swa_checkpoints.front(); |
3567 | 3575 |
3568 | | - auto & cur = slot.swa_checkpoints.emplace_back(); |
| 3576 | + SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); |
| 3577 | + } |
3569 | 3578 |
3570 | | - cur.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); |
3571 | | - cur.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); |
| 3579 | + slot.swa_checkpoints.erase(slot.swa_checkpoints.begin()); |
| 3580 | + } |
3572 | 3581 |
3573 | 3582 | const size_t swa_size = llama_state_seq_get_size(ctx, slot.id); |
3574 | | - cur.data.resize(swa_size); |
| 3583 | + |
| 3584 | + auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{ |
| 3585 | + /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id), |
| 3586 | + /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id), |
| 3587 | + /*.data = */ std::vector<uint8_t>(swa_size), |
| 3588 | + }); |
3575 | 3589 |
3576 | 3590 | llama_state_seq_get_data(ctx, cur.data.data(), swa_size, slot.id); |
3577 | 3591 |
3578 | | - SLT_WRN(slot, "prompt swa checkpoint, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); |
| 3592 | + SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024); |
3579 | 3593 | } |
3580 | 3594 | } else if (slot.state != SLOT_STATE_GENERATING) { |
3581 | 3595 | continue; // continue loop of slots |
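Checkpoint creation is now bounded as well: once a slot already holds `SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT` checkpoints, the oldest entry at `begin()` is erased before the new snapshot is appended, so the per-slot history behaves like a small sliding window over recent positions. A minimal sketch of that eviction behaviour, with the serialized llama state replaced by a plain byte vector; the types and names here are illustrative, not the server's own:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct checkpoint {
    int32_t pos_min;
    int32_t pos_max;
    std::vector<uint8_t> data;
};

constexpr size_t MAX_CHECKPOINTS = 3; // mirrors SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT

// Append a checkpoint, evicting the oldest one once the cap is reached.
static void push_checkpoint(std::vector<checkpoint> & cps, checkpoint cp) {
    if (cps.size() >= MAX_CHECKPOINTS) {
        cps.erase(cps.begin()); // drop the oldest snapshot
    }
    cps.push_back(std::move(cp));
}

int main() {
    std::vector<checkpoint> cps;

    for (int32_t i = 0; i < 5; i++) {
        push_checkpoint(cps, { /*.pos_min =*/ i * 100,
                               /*.pos_max =*/ i * 100 + 99,
                               /*.data    =*/ std::vector<uint8_t>(16) });
    }

    // only the three newest checkpoints survive: pos_min = 200, 300, 400
    assert(cps.size() == MAX_CHECKPOINTS);
    assert(cps.front().pos_min == 200 && cps.back().pos_min == 400);
    return 0;
}
```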