Skip to content

Commit d32e03f

Browse files
authored
server : add SWA checkpoints (#15293)
* server : add SWA checkpoints ggml-ci * cont : server clean-up * server : handle state restore fails * llama : add extended llama_state_seq_ API * server : do not make checkpoints if --swa-full ggml-ci * llama : remove flags value for NONE * server : configure number of SWA checkpoints with CLI arg ggml-ci * args : fix scope of new argument
1 parent 3973163 commit d32e03f

15 files changed

+205
-53
lines changed

common/arg.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,6 +1507,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
15071507
params.swa_full = true;
15081508
}
15091509
).set_env("LLAMA_ARG_SWA_FULL"));
1510+
add_opt(common_arg(
1511+
{"--swa-checkpoints"}, "N",
1512+
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
1513+
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
1514+
[](common_params & params, int value) {
1515+
params.n_swa_checkpoints = value;
1516+
}
1517+
).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
15101518
add_opt(common_arg(
15111519
{"--kv-unified", "-kvu"},
15121520
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"

common/common.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -413,11 +413,12 @@ struct common_params {
413413
std::string cls_sep = "\t"; // separator of classification sequences
414414

415415
// server params
416-
int32_t port = 8080; // server listens on this network port
417-
int32_t timeout_read = 600; // http read timeout in seconds
418-
int32_t timeout_write = timeout_read; // http write timeout in seconds
419-
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
420-
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
416+
int32_t port = 8080; // server listens on this network port
417+
int32_t timeout_read = 600; // http read timeout in seconds
418+
int32_t timeout_write = timeout_read; // http write timeout in seconds
419+
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
420+
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
421+
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
421422

422423
std::string hostname = "127.0.0.1";
423424
std::string public_path = ""; // NOLINT

include/llama.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,29 @@ extern "C" {
870870
size_t n_token_capacity,
871871
size_t * n_token_count_out);
872872

873+
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
874+
875+
typedef uint32_t llama_state_seq_flags;
876+
877+
LLAMA_API size_t llama_state_seq_get_size_ext(
878+
struct llama_context * ctx,
879+
llama_seq_id seq_id,
880+
llama_state_seq_flags flags);
881+
882+
LLAMA_API size_t llama_state_seq_get_data_ext(
883+
struct llama_context * ctx,
884+
uint8_t * dst,
885+
size_t size,
886+
llama_seq_id seq_id,
887+
llama_state_seq_flags flags);
888+
889+
LLAMA_API size_t llama_state_seq_set_data_ext(
890+
struct llama_context * ctx,
891+
const uint8_t * src,
892+
size_t size,
893+
llama_seq_id dest_seq_id,
894+
llama_state_seq_flags flags);
895+
873896
//
874897
// Decoding
875898
//

src/llama-context.cpp

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
16571657
}
16581658
}
16591659

1660-
size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
1660+
size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
16611661
llama_io_write_dummy io;
16621662
try {
1663-
return state_seq_write_data(io, seq_id);
1663+
return state_seq_write_data(io, seq_id, flags);
16641664
} catch (const std::exception & err) {
16651665
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
16661666
return 0;
16671667
}
16681668
}
16691669

1670-
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
1670+
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
16711671
llama_io_write_buffer io(dst, size);
16721672
try {
1673-
return state_seq_write_data(io, seq_id);
1673+
return state_seq_write_data(io, seq_id, flags);
16741674
} catch (const std::exception & err) {
16751675
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
16761676
return 0;
16771677
}
16781678
}
16791679

1680-
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
1680+
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
16811681
llama_io_read_buffer io(src, size);
16821682
try {
1683-
return state_seq_read_data(io, seq_id);
1683+
return state_seq_read_data(io, seq_id, flags);
16841684
} catch (const std::exception & err) {
16851685
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
16861686
return 0;
@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
17781778
{
17791779
const size_t state_size = file.size() - file.tell();
17801780
llama_io_read_file io(&file);
1781-
const size_t nread = state_seq_read_data(io, seq_id);
1781+
const size_t nread = state_seq_read_data(io, seq_id, 0);
17821782
if (!nread) {
17831783
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
17841784
return 0;
@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
18021802

18031803
// save the context state using stream saving
18041804
llama_io_write_file io(&file);
1805-
state_seq_write_data(io, seq_id);
1805+
state_seq_write_data(io, seq_id, 0);
18061806

18071807
const size_t res = file.tell();
18081808
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
19711971
return io.n_bytes();
19721972
}
19731973

1974-
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
1974+
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
19751975
GGML_UNUSED(seq_id);
19761976

19771977
if (memory) {
1978-
memory->state_write(io, seq_id);
1978+
memory->state_write(io, seq_id, flags);
19791979
}
19801980

19811981
return io.n_bytes();
19821982
}
19831983

1984-
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
1984+
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
19851985
GGML_UNUSED(seq_id);
19861986

19871987
if (memory) {
1988-
memory->state_read(io, seq_id);
1988+
memory->state_read(io, seq_id, flags);
19891989
}
19901990

19911991
return io.n_bytes();
@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
28012801
}
28022802

28032803
size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
2804-
return ctx->state_seq_get_size(seq_id);
2804+
return llama_state_seq_get_size_ext(ctx, seq_id, 0);
28052805
}
28062806

28072807
size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
2808+
return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
2809+
}
2810+
2811+
size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
2812+
return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
2813+
}
2814+
2815+
size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
2816+
return ctx->state_seq_get_size(seq_id, flags);
2817+
}
2818+
2819+
size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
28082820
ctx->synchronize();
28092821

2810-
return ctx->state_seq_get_data(seq_id, dst, size);
2822+
return ctx->state_seq_get_data(seq_id, dst, size, flags);
28112823
}
28122824

2813-
size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
2825+
size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
28142826
ctx->synchronize();
28152827

2816-
return ctx->state_seq_set_data(seq_id, src, size);
2828+
return ctx->state_seq_set_data(seq_id, src, size, flags);
28172829
}
28182830

28192831
size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {

src/llama-context.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ struct llama_context {
111111
size_t state_get_data( uint8_t * dst, size_t size);
112112
size_t state_set_data(const uint8_t * src, size_t size);
113113

114-
size_t state_seq_get_size(llama_seq_id seq_id);
115-
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
116-
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
114+
size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
115+
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
116+
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
117117

118118
bool state_load_file(
119119
const char * filepath,
@@ -213,8 +213,8 @@ struct llama_context {
213213
size_t state_write_data(llama_io_write_i & io);
214214
size_t state_read_data (llama_io_read_i & io);
215215

216-
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
217-
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
216+
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
217+
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
218218

219219
//
220220
// members

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const {
194194
return kv_base->get_size() == kv_swa->get_size();
195195
}
196196

197-
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
198-
kv_base->state_write(io, seq_id);
199-
kv_swa ->state_write(io, seq_id);
197+
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
198+
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
199+
kv_base->state_write(io, seq_id, flags);
200+
}
201+
202+
kv_swa->state_write(io, seq_id, flags);
200203
}
201204

202-
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
203-
kv_base->state_read(io, seq_id);
204-
kv_swa ->state_read(io, seq_id);
205+
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
206+
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
207+
kv_base->state_read(io, seq_id, flags);
208+
}
209+
210+
kv_swa->state_read(io, seq_id, flags);
205211
}
206212

207213
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {

src/llama-kv-cache-unified-iswa.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
5656

5757
// state write/load
5858

59-
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
60-
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
59+
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
60+
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
6161

6262
//
6363
// llama_kv_cache_unified_iswa specific API

src/llama-kv-cache-unified.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,7 +1828,9 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
18281828
return false;
18291829
}
18301830

1831-
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
1831+
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
1832+
GGML_UNUSED(flags);
1833+
18321834
io.write(&n_stream, sizeof(n_stream));
18331835

18341836
for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1879,7 +1881,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
18791881
}
18801882
}
18811883

1882-
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
1884+
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
1885+
GGML_UNUSED(flags);
1886+
18831887
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
18841888

18851889
uint32_t n_stream_cur;

src/llama-kv-cache-unified.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ class llama_kv_cache_unified : public llama_memory_i {
136136

137137
// state write/load
138138

139-
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
140-
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
139+
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
140+
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
141141

142142
//
143143
// llama_kv_cache_unified specific API

src/llama-memory-hybrid.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
165165
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
166166
}
167167

168-
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
168+
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
169+
GGML_UNUSED(flags);
170+
169171
mem_attn->state_write(io, seq_id);
170172
mem_recr->state_write(io, seq_id);
171173
}
172174

173-
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
175+
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
176+
GGML_UNUSED(flags);
177+
174178
mem_attn->state_read(io, seq_id);
175179
mem_recr->state_read(io, seq_id);
176180
}

0 commit comments

Comments
 (0)