Skip to content

Commit 52b775e

Browse files
committed
server : configure number of SWA checkpoints with CLI arg
ggml-ci
1 parent c2b5cfb commit 52b775e

File tree

3 files changed

+25
-11
lines changed

3 files changed

+25
-11
lines changed

common/arg.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,6 +1506,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
15061506
params.swa_full = true;
15071507
}
15081508
).set_env("LLAMA_ARG_SWA_FULL"));
1509+
add_opt(common_arg(
1510+
{"--swa-checkpoints"}, "N",
1511+
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
1512+
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
1513+
[](common_params & params, int value) {
1514+
params.n_swa_checkpoints = value;
1515+
}
1516+
).set_env("LLAMA_ARG_SWA_CHECKPOINTS"));
15091517
add_opt(common_arg(
15101518
{"--kv-unified", "-kvu"},
15111519
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"

common/common.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -385,11 +385,12 @@ struct common_params {
385385
std::string cls_sep = "\t"; // separator of classification sequences
386386

387387
// server params
388-
int32_t port = 8080; // server listens on this network port
389-
int32_t timeout_read = 600; // http read timeout in seconds
390-
int32_t timeout_write = timeout_read; // http write timeout in seconds
391-
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
392-
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
388+
int32_t port = 8080; // server listens on this network port
389+
int32_t timeout_read = 600; // http read timeout in seconds
390+
int32_t timeout_write = timeout_read; // http write timeout in seconds
391+
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
392+
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
393+
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
393394

394395
std::string hostname = "127.0.0.1";
395396
std::string public_path = ""; // NOLINT

tools/server/server.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
#include <unordered_map>
3232
#include <unordered_set>
3333

34-
#define SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT 3
35-
3634
using json = nlohmann::ordered_json;
3735

3836
constexpr int HTTP_POLLING_SECONDS = 1;
@@ -3579,12 +3577,13 @@ struct server_context {
35793577

35803578
// make a checkpoint with the SWA memory
35813579
// checkpoints are needed only if we are not using "--swa-full"
3582-
if (llama_model_n_swa(model) > 0 && !params_base.swa_full) {
3583-
if (slot.swa_checkpoints.size() >= SERVER_MAX_SWA_CHECKPOINTS_PER_SLOT) {
3580+
if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
3581+
if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
35843582
{
35853583
const auto & cur = slot.swa_checkpoints.back();
35863584

3587-
SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
3585+
SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
3586+
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
35883587
}
35893588

35903589
slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
@@ -3600,7 +3599,13 @@ struct server_context {
36003599

36013600
llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
36023601

3603-
SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
3602+
float size_total = 0.0f;
3603+
for (const auto & checkpoint : slot.swa_checkpoints) {
3604+
size_total += (float) checkpoint.data.size() / 1024 / 1024;
3605+
}
3606+
3607+
SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
3608+
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
36043609
}
36053610
} else if (slot.state != SLOT_STATE_GENERATING) {
36063611
continue; // continue loop of slots

0 commit comments

Comments
 (0)