Commit 487b922

cont : server clean-up

1 parent 96db966

1 file changed: 49 additions, 35 deletions

tools/server/server.cpp

@@ -31,6 +31,8 @@
 #include <unordered_map>
 #include <unordered_set>
 
+#define SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT 3
+
 using json = nlohmann::ordered_json;
 
 constexpr int HTTP_POLLING_SECONDS = 1;
@@ -693,10 +695,10 @@ struct completion_token_output {
 };
 
 struct swa_checkpoint {
-    std::vector<uint8_t> data;
-
     llama_pos pos_min;
     llama_pos pos_max;
+
+    std::vector<uint8_t> data;
 };
 
 struct server_task_result_cmpl_final : server_task_result {
@@ -3300,50 +3302,56 @@ struct server_context {
                         slot.n_past = 0;
                     }
 
+                    const auto n_swa = llama_model_n_swa(model);
+
                     if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                         const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                         if (pos_min == -1) {
                             SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
                             GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                         }
 
-                        const auto n_swa = llama_model_n_swa(model);
-                        if (pos_min > std::max(0, slot.n_past - n_swa)) {
+                        const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                        if (pos_min > pos_min_thold) {
                             // search for a SWA checkpoint
-                            int ic = -1;
-                            int np = std::numeric_limits<int>::max();
-                            for (int i = 0; i < (int) slot.swa_checkpoints.size(); i++) {
-                                const auto & cur = slot.swa_checkpoints[i];
-                                if (cur.pos_min <= std::max(0, slot.n_past - n_swa)) {
-                                    const int p = std::max(0, slot.n_past - cur.pos_max);
-
-                                    if (p < np) {
-                                        ic = i;
-                                        np = p;
-                                    }
+                            auto it = std::find_if(
+                                slot.swa_checkpoints.rbegin(),
+                                slot.swa_checkpoints.rend(),
+                                [&](const auto & cur) {
+                                    return cur.pos_min <= pos_min_thold;
                                 }
-                            }
+                            );
 
-                            if (ic == -1) {
+                            if (it == slot.swa_checkpoints.rend()) {
                                 SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
                                 SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                         "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                                slot.n_past = 0;
 
+                                slot.n_past = 0;
                                 slot.swa_checkpoints.clear();
                             } else {
-                                // erase all checkpoints after the one we are using
-                                slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + ic + 1, slot.swa_checkpoints.end());
-
                                 // restore the checkpoint
-                                const auto & cur = slot.swa_checkpoints[ic];
+                                const size_t swa_size = it->data.size();
+                                llama_state_seq_set_data(ctx, it->data.data(), swa_size, slot.id);
 
-                                const size_t swa_size = cur.data.size();
-                                llama_state_seq_set_data(ctx, cur.data.data(), swa_size, slot.id);
+                                slot.n_past = std::min(slot.n_past, it->pos_max);
 
-                                slot.n_past = std::min(slot.n_past, cur.pos_max);
+                                SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                            }
+                        }
+                    }
 
-                                SLT_WRN(slot, "prompt swa checkpoint restored, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
+                    if (n_swa > 0) {
+                        const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                        // erase any checkpoints with pos_min > pos_min_thold
+                        for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
+                            const auto & cur = slot.swa_checkpoints[i];
+                            if (cur.pos_min > pos_min_thold) {
+                                slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
+
+                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                             }
                         }
                     }
@@ -3559,23 +3567,29 @@ struct server_context {
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
 
-                    // make a checkpoint
+                    // make a checkpoint with the SWA memory
                     if (llama_model_n_swa(model) > 0) {
-                        if (slot.swa_checkpoints.size() > 8) {
-                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
-                        }
+                        if (slot.swa_checkpoints.size() >= SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT) {
+                            {
+                                const auto & cur = slot.swa_checkpoints.back();
 
-                        auto & cur = slot.swa_checkpoints.emplace_back();
+                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                            }
 
-                        cur.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
-                        cur.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                        }
 
                         const size_t swa_size = llama_state_seq_get_size(ctx, slot.id);
-                        cur.data.resize(swa_size);
+
+                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                            /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
+                            /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
+                            /*.data = */ std::vector<uint8_t>(swa_size),
+                        });
 
                         llama_state_seq_get_data(ctx, cur.data.data(), swa_size, slot.id);
 
-                        SLT_WRN(slot, "prompt swa checkpoint, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
+                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %f MiB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
                     }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots