Merged
Changes from 7 commits
8 changes: 8 additions & 0 deletions common/arg.cpp
@@ -1506,6 +1506,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.swa_full = true;
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--swa-checkpoints"}, "N",
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
[](common_params & params, int value) {
params.n_swa_checkpoints = value;
}
).set_env("LLAMA_ARG_SWA_CHECKPOINTS"));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
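Usage note: the new option can be set on the command line as `--swa-checkpoints N` or through the `LLAMA_ARG_SWA_CHECKPOINTS` environment variable; the default of 3 checkpoints per slot comes from `common_params::n_swa_checkpoints` in `common/common.h` below.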
11 changes: 6 additions & 5 deletions common/common.h
@@ -385,11 +385,12 @@ struct common_params {
std::string cls_sep = "\t"; // separator of classification sequences

// server params
int32_t port = 8080; // server listens on this network port
int32_t timeout_read = 600; // http read timeout in seconds
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t port = 8080; // server listens on this network port
int32_t timeout_read = 600; // http read timeout in seconds
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot

std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
23 changes: 23 additions & 0 deletions include/llama.h
@@ -870,6 +870,29 @@ extern "C" {
size_t n_token_capacity,
size_t * n_token_count_out);

#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1

typedef uint32_t llama_state_seq_flags;

LLAMA_API size_t llama_state_seq_get_size_ext(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_state_seq_flags flags);

LLAMA_API size_t llama_state_seq_get_data_ext(
struct llama_context * ctx,
uint8_t * dst,
size_t size,
llama_seq_id seq_id,
llama_state_seq_flags flags);

LLAMA_API size_t llama_state_seq_set_data_ext(
struct llama_context * ctx,
const uint8_t * src,
size_t size,
llama_seq_id dest_seq_id,
llama_state_seq_flags flags);

//
// Decoding
//
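A minimal usage sketch of the new `_ext` entry points (not code from this PR; `ctx` and `seq_id` are assumed to be set up elsewhere). With `flags == 0` the behaviour matches the existing `llama_state_seq_*` functions, which, as the `src/llama-context.cpp` hunk below shows, now simply forward to the `_ext` variants:

```cpp
// Sketch: full vs. SWA-only sequence state snapshots via the new *_ext API.
#include "llama.h"

#include <vector>

static void state_api_sketch(llama_context * ctx, llama_seq_id seq_id) {
    // flags == 0 behaves like the existing llama_state_seq_* calls: full sequence state
    std::vector<uint8_t> full(llama_state_seq_get_size_ext(ctx, seq_id, 0));
    llama_state_seq_get_data_ext(ctx, full.data(), full.size(), seq_id, 0);

    // LLAMA_STATE_SEQ_FLAGS_SWA_ONLY serializes only the SWA part of the KV cache,
    // which for iSWA models is typically much smaller than the full state
    std::vector<uint8_t> swa(llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY));
    llama_state_seq_get_data_ext(ctx, swa.data(), swa.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

    // ... later, restore the SWA-only snapshot into the same sequence
    llama_state_seq_set_data_ext(ctx, swa.data(), swa.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
}
```

For memory types without a separate SWA cache (`llama_kv_cache_unified` on its own, `llama_memory_hybrid`, `llama_memory_recurrent`) the flags argument is currently accepted but unused, as the hunks further down show.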
44 changes: 28 additions & 16 deletions src/llama-context.cpp
@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
}
}

size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
llama_io_write_dummy io;
try {
return state_seq_write_data(io, seq_id);
return state_seq_write_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
return 0;
}
}

size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
llama_io_write_buffer io(dst, size);
try {
return state_seq_write_data(io, seq_id);
return state_seq_write_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
return 0;
}
}

size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
llama_io_read_buffer io(src, size);
try {
return state_seq_read_data(io, seq_id);
return state_seq_read_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
return 0;
@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
{
const size_t state_size = file.size() - file.tell();
llama_io_read_file io(&file);
const size_t nread = state_seq_read_data(io, seq_id);
const size_t nread = state_seq_read_data(io, seq_id, 0);
if (!nread) {
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
return 0;
@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file

// save the context state using stream saving
llama_io_write_file io(&file);
state_seq_write_data(io, seq_id);
state_seq_write_data(io, seq_id, 0);

const size_t res = file.tell();
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
return io.n_bytes();
}

size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(seq_id);

if (memory) {
memory->state_write(io, seq_id);
memory->state_write(io, seq_id, flags);
}

return io.n_bytes();
}

size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(seq_id);

if (memory) {
memory->state_read(io, seq_id);
memory->state_read(io, seq_id, flags);
}

return io.n_bytes();
@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
}

size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
return ctx->state_seq_get_size(seq_id);
return llama_state_seq_get_size_ext(ctx, seq_id, 0);
}

size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
}

size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
}

size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
return ctx->state_seq_get_size(seq_id, flags);
}

size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
ctx->synchronize();

return ctx->state_seq_get_data(seq_id, dst, size);
return ctx->state_seq_get_data(seq_id, dst, size, flags);
}

size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
ctx->synchronize();

return ctx->state_seq_set_data(seq_id, src, size);
return ctx->state_seq_set_data(seq_id, src, size, flags);
}

size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
10 changes: 5 additions & 5 deletions src/llama-context.h
@@ -111,9 +111,9 @@ struct llama_context {
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);

size_t state_seq_get_size(llama_seq_id seq_id);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);

bool state_load_file(
const char * filepath,
@@ -212,8 +212,8 @@ struct llama_context {
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);

size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);

//
// members
18 changes: 12 additions & 6 deletions src/llama-kv-cache-unified-iswa.cpp
@@ -194,14 +194,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const {
return kv_base->get_size() == kv_swa->get_size();
}

void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
kv_base->state_write(io, seq_id);
kv_swa ->state_write(io, seq_id);
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}

kv_swa->state_write(io, seq_id, flags);
}

void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
kv_base->state_read(io, seq_id);
kv_swa ->state_read(io, seq_id);
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}

kv_swa->state_read(io, seq_id, flags);
}

llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
4 changes: 2 additions & 2 deletions src/llama-kv-cache-unified-iswa.h
@@ -56,8 +56,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

//
// llama_kv_cache_unified_iswa specific API
8 changes: 6 additions & 2 deletions src/llama-kv-cache-unified.cpp
@@ -1828,7 +1828,9 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
return false;
}

void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

io.write(&n_stream, sizeof(n_stream));

for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1879,7 +1881,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
}
}

void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);

GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));

uint32_t n_stream_cur;
4 changes: 2 additions & 2 deletions src/llama-kv-cache-unified.h
@@ -136,8 +136,8 @@ class llama_kv_cache_unified : public llama_memory_i {

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

//
// llama_kv_cache_unified specific API
8 changes: 6 additions & 2 deletions src/llama-memory-hybrid.cpp
@@ -165,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}

void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
}

void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);

mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
}
4 changes: 2 additions & 2 deletions src/llama-memory-hybrid.h
@@ -74,8 +74,8 @@ class llama_memory_hybrid : public llama_memory_i {

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

//
// llama_memory_hybrid specific API
8 changes: 6 additions & 2 deletions src/llama-memory-recurrent.cpp
@@ -680,7 +680,9 @@ size_t llama_memory_recurrent::size_s_bytes() const {
return size_s_bytes;
}

void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;

@@ -718,7 +720,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
state_write_data(io, cell_ranges);
}

void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);

uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));

4 changes: 2 additions & 2 deletions src/llama-memory-recurrent.h
@@ -63,8 +63,8 @@ class llama_memory_recurrent : public llama_memory_i {

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
4 changes: 2 additions & 2 deletions src/llama-memory.h
@@ -104,8 +104,8 @@ struct llama_memory_i {
// state write/read
//

virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};

using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
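Putting the pieces together, here is a rough, hypothetical sketch of the kind of per-slot bookkeeping these primitives enable: keep at most `n_swa_checkpoints` SWA-only snapshots per sequence and roll back to the newest one at or before a target position. The actual server-side changes are not part of this excerpt, and the `swa_checkpoint`, `checkpoint_add` and `checkpoint_restore` names are made up for illustration:

```cpp
// Hypothetical checkpoint bookkeeping built on the new *_ext API (not code from this PR).
#include "llama.h"

#include <deque>
#include <utility>
#include <vector>

struct swa_checkpoint {
    llama_pos            pos;  // sequence position at checkpoint time
    std::vector<uint8_t> data; // SWA-only state blob
};

// keep at most n_max checkpoints for one sequence, dropping the oldest
static void checkpoint_add(std::deque<swa_checkpoint> & cps, llama_context * ctx,
                           llama_seq_id seq_id, llama_pos pos, int n_max) {
    swa_checkpoint cp;
    cp.pos = pos;
    cp.data.resize(llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY));
    llama_state_seq_get_data_ext(ctx, cp.data.data(), cp.data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);

    cps.push_back(std::move(cp));
    while ((int) cps.size() > n_max) {
        cps.pop_front();
    }
}

// restore the newest checkpoint at or before the target position; returns its pos, or -1 if none fits
static llama_pos checkpoint_restore(std::deque<swa_checkpoint> & cps, llama_context * ctx,
                                    llama_seq_id seq_id, llama_pos target) {
    for (auto it = cps.rbegin(); it != cps.rend(); ++it) {
        if (it->pos <= target) {
            llama_state_seq_set_data_ext(ctx, it->data.data(), it->data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
            return it->pos;
        }
    }
    return -1;
}
```

Restricting snapshots to the SWA portion keeps each one small, which is presumably why the PR caps their count per slot with `--swa-checkpoints` rather than storing full sequence states.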