
Commit c699abc

llama : add param to control SWA cache size
ggml-ci
1 parent 84742ef commit c699abc

11 files changed: +63 −27 lines changed


common/arg.cpp

Lines changed: 8 additions & 0 deletions

@@ -1445,6 +1445,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_keep = value;
         }
     ));
+    add_opt(common_arg(
+        {"--swa-full"},
+        string_format("use full-size SWA cache (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)", params.swa_full ? "true" : "false"),
+        [](common_params & params) {
+            params.swa_full = true;
+        }
+    ));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
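The new flag is a plain store-true option: passing it flips `common_params::swa_full`, which `common_context_params_to_llama()` then copies into the context params (see common/common.cpp below). A minimal sketch of how a tool built on `common` would pick it up; the `common_params_parse()` entry point and its exact signature are assumptions here, not part of this commit:

```cpp
#include "arg.h"     // common_params_parse (assumed entry point from common/arg.h)
#include "common.h"  // common_params, common_context_params_to_llama
#include "llama.h"

#include <cstdio>

int main(int argc, char ** argv) {
    common_params params; // swa_full defaults to false (see common/common.h below)

    // e.g. invoked as: ./demo -m model.gguf --swa-full
    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }

    printf("swa_full = %s\n", params.swa_full ? "true" : "false");

    // the boolean is forwarded to the library via common_context_params_to_llama()
    const llama_context_params cparams = common_context_params_to_llama(params);
    return cparams.swa_full == params.swa_full ? 0 : 1;
}
```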

common/common.cpp

Lines changed: 1 addition & 0 deletions

@@ -1133,6 +1133,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
     cparams.op_offload = !params.no_op_offload;
+    cparams.swa_full = params.swa_full;
 
     if (params.reranking) {
         cparams.embeddings = true;

common/common.h

Lines changed: 1 addition & 0 deletions

@@ -323,6 +323,7 @@ struct common_params {
     bool flash_attn = false; // flash attention
     bool no_perf = false; // disable performance metrics
     bool ctx_shift = true; // context shift on inifinite text generation
+    bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap = true; // use mmap for faster loads

include/llama.h

Lines changed: 5 additions & 4 deletions

@@ -361,10 +361,11 @@ extern "C" {
 
         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
         bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
-        bool op_offload;  // whether to offload host tensor operations to device
+        bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // use flash attention [EXPERIMENTAL]
+        bool no_perf;     // measure performance timings
+        bool op_offload;  // offload host tensor operations to device
+        bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     };
 
     // model quantization parameters
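For programs using the C API directly, the new field is just another boolean on `llama_context_params`. A hedged sketch of opting out of the full-size SWA cache when creating a context; the model path is hypothetical, and `llama_model_load_from_file` / `llama_init_from_model` are assumed to be the current loader and context-creation entry points:

```cpp
#include "llama.h"

#include <cstdio>

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // hypothetical path
    if (!model) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx    = 32768;
    cparams.swa_full = false; // size the SWA cache to the attention window and enable pruning
                              // (the library default after this commit is true, i.e. full-size)

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        fprintf(stderr, "failed to create context\n");
        llama_model_free(model);
        return 1;
    }

    // ... run inference ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}
```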

src/llama-context.cpp

Lines changed: 4 additions & 2 deletions

@@ -177,8 +177,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
         llama_memory_params params_mem = {
-            /*.type_k =*/ params.type_k,
-            /*.type_v =*/ params.type_v,
+            /*.type_k   =*/ params.type_k,
+            /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
         };
 
         memory.reset(model.create_memory(params_mem, cparams));
@@ -2092,6 +2093,7 @@ llama_context_params llama_context_default_params() {
         /*.flash_attn =*/ false,
         /*.no_perf =*/ true,
         /*.op_offload =*/ true,
+        /*.swa_full =*/ true,
     };
 
     return result;
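Note the asymmetry in defaults introduced here: `llama_context_default_params()` enables `swa_full` (full-size cache, preserving prior behavior for direct API users), while `common_params` defaults it to `false`, so the CLI tools get the window-sized cache unless `--swa-full` is passed, and llama-bench pins it to `false` outright. A tiny check against the library default (a sketch, assuming only llama.h):

```cpp
#include "llama.h"

int main() {
    // library-level default added by this commit: full-size SWA cache
    const llama_context_params cparams = llama_context_default_params();
    return cparams.swa_full ? 0 : 1; // expected to return 0 with this commit applied
}
```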

src/llama-kv-cache.cpp

Lines changed: 34 additions & 18 deletions

@@ -1656,27 +1656,38 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool v_trans,
         bool offload,
         uint32_t kv_size,
+        bool swa_full,
         uint32_t n_seq_max,
         uint32_t n_batch,
         uint32_t padding) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
-    const uint32_t kv_size_base = kv_size;
-    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+    const uint32_t size_base = kv_size;
 
-    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+
+    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
+    if (swa_full) {
+        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+        size_swa = size_base;
+        do_prune = false;
+    }
+
+    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, kv_size_base, padding,
+            v_trans, offload, size_base, padding,
             0, LLAMA_SWA_TYPE_NONE);
 
-    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, kv_size_swa);
+    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, kv_size_swa, padding,
+            v_trans, offload, size_swa, padding,
             hparams.n_swa, hparams.swa_type);
 }
 
@@ -1733,8 +1744,11 @@ void llama_kv_cache_unified_iswa::commit() {
     kv_swa ->commit();
 
     // slide the attention window, forgetting/pruning old tokens that are outside the window
-    for (const auto & [seq_id, entry] : pending.pos) {
-        kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
+    if (do_prune) {
+        for (const auto & [seq_id, entry] : pending.pos) {
+            kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
+        }
+
     }
 
     pending.clear();
@@ -1762,17 +1776,19 @@ void llama_kv_cache_unified_iswa::set_full() {
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
     pending.clear();
 
-    for (int i = 0; i < batch.n_tokens; ++i) {
-        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[i][s];
-            const llama_pos pos = batch.pos[i];
+    if (do_prune) {
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+                const llama_pos pos = batch.pos[i];
 
-            if (pending.pos.find(seq_id) == pending.pos.end()) {
-                pending.pos[seq_id].pmin = pos;
-                pending.pos[seq_id].pmax = pos;
-            } else {
-                pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
-                pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
+                if (pending.pos.find(seq_id) == pending.pos.end()) {
+                    pending.pos[seq_id].pmin = pos;
+                    pending.pos[seq_id].pmax = pos;
+                } else {
+                    pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
+                    pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
+                }
             }
         }
     }
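The sizing rule above is the core of the change: unless `swa_full` is set, the SWA cache only needs to cover the attention window for each sequence plus one batch, rounded up to the cache padding. A standalone sketch of that computation with hypothetical numbers; the `pad()` helper mirrors ggml's `GGML_PAD` macro (round `x` up to a multiple of `n`), and the hparams values are illustrative, not taken from any particular model:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// round x up to a multiple of n (same result as GGML_PAD for the power-of-two paddings used here)
static uint32_t pad(uint32_t x, uint32_t n) { return ((x + n - 1) / n) * n; }

int main() {
    const uint32_t kv_size   = 32768; // base (non-SWA) cache size, i.e. n_ctx
    const uint32_t n_swa     = 1024;  // sliding-window width from the model hparams (hypothetical)
    const uint32_t n_seq_max = 1;
    const uint32_t n_batch   = 2048;
    const uint32_t padding   = 256;

    const uint32_t size_base = kv_size;
    uint32_t       size_swa  = std::min(size_base, pad(n_swa*n_seq_max + n_batch, padding));

    // with swa_full == true the SWA cache keeps the full size and pruning is disabled
    const bool swa_full = false;
    if (swa_full) {
        size_swa = size_base;
    }

    printf("non-SWA cells: %u, SWA cells: %u\n", size_base, size_swa); // 32768, 3072
    return 0;
}
```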

src/llama-kv-cache.h

Lines changed: 4 additions & 0 deletions

@@ -318,6 +318,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
             bool v_trans,
             bool offload,
             uint32_t kv_size,
+            bool swa_full,
             uint32_t n_seq_max,
             uint32_t n_batch,
             uint32_t padding);
@@ -380,6 +381,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
 private:
     const llama_hparams & hparams;
 
+    bool do_prune = true;
+
     struct {
         struct entry {
             llama_pos pmin;
@@ -390,6 +393,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
             pos.clear();
         }
 
+        // used to perform SWA pruning of old tokens
         std::unordered_map<llama_seq_id, entry> pos;
     } pending;
 
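The `pending` bookkeeping that `do_prune` now gates is just a per-sequence min/max of the positions seen in the current batch; `commit()` later hands those bounds to `prune_swa()`. A simplified standalone sketch of that accumulation, with plain typedefs standing in for `llama_seq_id`/`llama_pos`:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <unordered_map>

// stand-ins for llama_seq_id / llama_pos
using seq_id_t = int32_t;
using pos_t    = int32_t;

struct entry {
    pos_t pmin;
    pos_t pmax;
};

// record a (seq_id, pos) pair, widening the per-sequence [pmin, pmax] range
static void track(std::unordered_map<seq_id_t, entry> & pos, seq_id_t seq_id, pos_t p) {
    auto it = pos.find(seq_id);
    if (it == pos.end()) {
        pos[seq_id] = { p, p };
    } else {
        it->second.pmin = std::min(it->second.pmin, p);
        it->second.pmax = std::max(it->second.pmax, p);
    }
}

int main() {
    std::unordered_map<seq_id_t, entry> pending_pos;

    // hypothetical batch: sequence 0 covers positions 100..102, sequence 1 covers position 7
    track(pending_pos, 0, 100);
    track(pending_pos, 0, 101);
    track(pending_pos, 0, 102);
    track(pending_pos, 1, 7);

    for (const auto & [seq_id, e] : pending_pos) {
        printf("seq %d: pmin = %d, pmax = %d\n", seq_id, e.pmin, e.pmax);
        // the real cache would call kv_swa->prune_swa(seq_id, e.pmin, e.pmax) here
    }
    return 0;
}
```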
src/llama-memory.h

Lines changed: 2 additions & 2 deletions

@@ -7,8 +7,8 @@ struct llama_memory_params {
     ggml_type type_k;
     ggml_type type_v;
 
-    // parameters for other types of memory
-    // ...
+    // use full-size SWA cache
+    bool swa_full;
 };
 
 // general concept of LLM memory

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions

@@ -13227,6 +13227,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                        !cparams.flash_attn,
                         cparams.offload_kqv,
                         cparams.n_ctx,
+                        params.swa_full,
                         cparams.n_seq_max,
                         cparams.n_batch,
                         padding);

tools/llama-bench/llama-bench.cpp

Lines changed: 1 addition & 0 deletions

@@ -991,6 +991,7 @@ struct cmd_params_instance {
         cparams.flash_attn = flash_attn;
         cparams.embeddings = embeddings;
         cparams.op_offload = !no_op_offload;
+        cparams.swa_full = false;
 
         return cparams;
     }
