Commit b730706

kv-cache : support layer reuse (ggml-org#15504)
* kv-cache : support layer reuse

ggml-ci

* cont : update comments

[no ci]
1 parent c9a24fb commit b730706

12 files changed: +203 -136 lines

src/llama-hparams.cpp

Lines changed: 25 additions & 0 deletions
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {
 
     GGML_ABORT("fatal error");
 }
+
+bool llama_hparams::has_kv(uint32_t il) const {
+    if (n_layer_kv_from_start >= 0) {
+        if (il < (uint32_t) n_layer_kv_from_start) {
+            return true;
+        }
+
+        return false;
+    }
+
+    // by default, all layers have kv
+    return true;
+}
+
+uint32_t llama_hparams::n_layer_kv() const {
+    uint32_t res = 0;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (has_kv(il)) {
+            res++;
+        }
+    }
+
+    return res;
+}
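
For illustration only (not part of the diff): a minimal sketch of how the new helpers behave, assuming a hypothetical 30-layer model that only gives KV cache to its first 20 layers.

// Sketch under assumed values; only the two fields set here matter for the example.
llama_hparams hparams = {};

hparams.n_layer               = 30;
hparams.n_layer_kv_from_start = 20;

// hparams.has_kv(0)    -> true   (0  < 20)
// hparams.has_kv(19)   -> true   (19 < 20)
// hparams.has_kv(20)   -> false  (20 >= 20)
// hparams.n_layer_kv() == 20     (layers 0..19)

// With n_layer_kv_from_start left at its default (-1), has_kv() returns true for every
// layer and n_layer_kv() == n_layer, i.e. the previous behavior is unchanged.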

src/llama-hparams.h

Lines changed: 6 additions & 0 deletions
@@ -41,6 +41,7 @@ struct llama_hparams {
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
+    int32_t  n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
     uint32_t n_rot;
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
     uint32_t n_pos_per_embd() const;
 
     bool is_swa(uint32_t il) const;
+
+    bool has_kv(uint32_t il) const;
+
+    // number of layers for which has_kv() returns true
+    uint32_t n_layer_kv() const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

src/llama-kv-cache-iswa.cpp

Lines changed: 24 additions & 7 deletions
@@ -22,9 +22,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
         uint32_t kv_size,
         uint32_t n_seq_max,
         uint32_t n_ubatch,
-        uint32_t n_pad) : hparams(model.hparams), unified(unified) {
-    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+        uint32_t n_pad,
+        const layer_filter_cb & filter,
+        const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
+
+    // chain filters
+    const layer_filter_cb filter_base = [&](int32_t il) {
+        if (filter && !filter(il)) {
+            return false;
+        }
+
+        return !model.hparams.is_swa(il);
+    };
+
+    const layer_filter_cb filter_swa = [&](int32_t il) {
+        if (filter && !filter(il)) {
+            return false;
+        }
+
+        return model.hparams.is_swa(il);
+    };
 
     const uint32_t size_base = kv_size;
 
@@ -41,16 +58,16 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
 
     kv_base = std::make_unique<llama_kv_cache>(
-            model, std::move(filter_base), type_k, type_v,
+            model, type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, LLAMA_SWA_TYPE_NONE);
+            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache>(
-            model, std::move(filter_swa), type_k, type_v,
+            model, type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, hparams.swa_type);
+            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
 }
 
 void llama_kv_cache_iswa::clear(bool data) {
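
A side note on the chained filters above, for illustration only: a layer rejected by the caller-supplied filter is rejected by both lambdas and therefore lands in neither child cache, while accepted layers are split by is_swa(). A self-contained sketch of that routing follows; the enum and helper are hypothetical, not part of the commit.

#include <cstdint>
#include <functional>

// Illustration only: where a layer ends up, given the caller's filter and its SWA flag.
enum class kv_dest { none, base, swa };

static kv_dest route_layer(const std::function<bool(int32_t)> & filter, bool is_swa, int32_t il) {
    if (filter && !filter(il)) {
        return kv_dest::none; // rejected by the caller's filter in both chained lambdas
    }

    return is_swa ? kv_dest::swa   // goes to kv_swa
                  : kv_dest::base; // goes to kv_base
}

// The reuse callback plays no role in this split; it is forwarded unchanged to both child caches.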

src/llama-kv-cache-iswa.h

Lines changed: 4 additions & 2 deletions
@@ -20,11 +20,13 @@ class llama_kv_cache_iswa : public llama_memory_i {
             bool v_trans,
             bool offload,
             bool swa_full,
-            bool ,
+            bool unified,
             uint32_t kv_size,
             uint32_t n_seq_max,
             uint32_t n_ubatch,
-            uint32_t n_pad);
+            uint32_t n_pad,
+            const layer_filter_cb & filter,
+            const layer_reuse_cb & reuse);
 
     ~llama_kv_cache_iswa() = default;

src/llama-kv-cache.cpp

Lines changed: 35 additions & 33 deletions
@@ -17,40 +17,33 @@
 //
 
 llama_kv_cache::llama_kv_cache(
-        const llama_model & model,
-        layer_filter_cb && filter,
-        ggml_type type_k,
-        ggml_type type_v,
-        bool v_trans,
-        bool offload,
-        bool unified,
-        uint32_t kv_size,
-        uint32_t n_seq_max,
-        uint32_t n_pad,
-        uint32_t n_swa,
-        llama_swa_type swa_type) :
+        const llama_model & model,
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        bool offload,
+        bool unified,
+        uint32_t kv_size,
+        uint32_t n_seq_max,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        const layer_filter_cb & filter,
+        const layer_reuse_cb & reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
     n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    auto n_layer_cache = hparams.n_layer;
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        n_layer_cache = 20;
-    }
-    if (model.arch == LLM_ARCH_GLM4_MOE) {
-        // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
-    }
+    const uint32_t n_layer_kv = hparams.n_layer_kv();
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
                 __func__, hparams.n_embd_v_gqa_max());
     }
 
-    for (uint32_t il = 0; il < n_layer_cache; il++) {
+    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        if (!hparams.has_kv(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
+            continue;
+        }
+
         if (filter && !filter(il)) {
-            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+            LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
             continue;
         }
 
@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
         layers.push_back({ il, k, v, k_stream, v_stream, });
     }
 
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+    if (reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
 
-        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
-            if (filter && !filter(il)) {
-                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+        for (uint32_t il = 0; il < hparams.n_layer; il++) {
+            const int32_t il_reuse = reuse(il);
+
+            if (il_reuse < 0) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
                 continue;
             }
 
-            const bool is_swa = hparams.is_swa(il);
-            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
+                continue;
+            }
 
             GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+
             map_layer_ids[il] = map_layer_ids[il_reuse];
 
-            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+            LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
         }
     }
 
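The removed hard-coded branches are exactly what the new reuse callback generalizes. As a hedged sketch (the model-side code that builds these callbacks is outside this excerpt), a layer_reuse_cb reproducing the old Gemma3n special case could look like this:

// Sketch only: equivalent of the removed GEMMA3N logic, expressed as a reuse callback.
// 'hparams' is the model's llama_hparams; 20 is the previously hard-coded n_layer_cache value.
const auto reuse_gemma3n = [&hparams](int32_t il) -> int32_t {
    const int32_t n_layer_cache = 20;

    if (il < n_layer_cache) {
        return -1; // negative: this layer keeps its own KV cells (no reuse)
    }

    // layers past the cached range reuse the last SWA (18) or last non-SWA (19) cached layer,
    // matching the removed "n_layer_cache - (is_swa ? 2 : 1)" computation
    return n_layer_cache - (hparams.is_swa(il) ? 2 : 1);
};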

src/llama-kv-cache.h

Lines changed: 13 additions & 15 deletions
@@ -21,9 +21,6 @@ class llama_kv_cache : public llama_memory_i {
 public:
     static uint32_t get_padding(const llama_cparams & cparams);
 
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     struct stream_copy_info {
         bool empty() const {
             assert(ssrc.size() == sdst.size());
@@ -82,18 +79,19 @@ class llama_kv_cache : public llama_memory_i {
     using slot_info_vec_t = std::vector<slot_info>;
 
     llama_kv_cache(
-            const llama_model & model,
-            layer_filter_cb && filter,
-            ggml_type type_k,
-            ggml_type type_v,
-            bool v_trans,
-            bool offload,
-            bool unified,
-            uint32_t kv_size,
-            uint32_t n_seq_max,
-            uint32_t n_pad,
-            uint32_t n_swa,
-            llama_swa_type swa_type);
+            const llama_model & model,
+            ggml_type type_k,
+            ggml_type type_v,
+            bool v_trans,
+            bool offload,
+            bool unified,
+            uint32_t kv_size,
+            uint32_t n_seq_max,
+            uint32_t n_pad,
+            uint32_t n_swa,
+            llama_swa_type swa_type,
+            const layer_filter_cb & filter,
+            const layer_reuse_cb & reuse);
 
     ~llama_kv_cache() = default;
 
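The layer_filter_cb alias removed here (and from llama_memory_hybrid below) is not gone; together with the new layer_reuse_cb it is presumably declared once on the shared llama_memory_i interface, in one of the changed files not shown in this excerpt. Based on the call sites, the declarations would look roughly like the following (an assumption, not a quote from the commit):

// Assumed shape of the shared callback types (their actual location is not shown here):

// filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;

// map a layer to the layer whose KV cells it should reuse;
// a negative return value means the layer keeps its own cells
using layer_reuse_cb = std::function<int32_t(int32_t il)>;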

src/llama-memory-hybrid.cpp

Lines changed: 29 additions & 28 deletions
@@ -9,32 +9,29 @@
 //
 
 llama_memory_hybrid::llama_memory_hybrid(
-        const llama_model & model,
-        /* attn */
-        ggml_type type_k,
-        ggml_type type_v,
-        bool v_trans,
-        uint32_t kv_size,
-        uint32_t n_pad,
-        uint32_t n_swa,
-        llama_swa_type swa_type,
-        /* recurrent */
-        ggml_type type_r,
-        ggml_type type_s,
-        uint32_t rs_size,
-        /* common */
-        uint32_t n_seq_max,
-        bool offload,
-        bool unified,
-        /* layer filters */
-        layer_filter_cb && filter_attn,
-        layer_filter_cb && filter_recr) :
+        const llama_model & model,
+        /* attn */
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        bool unified,
+        /* layer filters */
+        const layer_filter_cb & filter_attn,
+        const layer_filter_cb & filter_recr) :
     hparams(model.hparams),
     mem_attn(new llama_kv_cache(
         model,
-        filter_attn == nullptr ?
-            [&](int32_t il) { return !hparams.is_recurrent(il); }
-            : filter_attn,
         type_k,
         type_v,
         v_trans,
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max,
         n_pad,
         n_swa,
-        swa_type
+        swa_type,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !hparams.is_recurrent(il); }
+            : filter_attn,
+        nullptr
     )),
     mem_recr(new llama_memory_recurrent(
         model,
-        filter_recr == nullptr ?
-            [&](int32_t il) { return hparams.is_recurrent(il); }
-            : filter_recr,
        type_r,
        type_s,
        offload,
        rs_size,
-        n_seq_max
+        n_seq_max,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return hparams.is_recurrent(il); }
+            : filter_recr
     )) {}
 
 llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {

src/llama-memory-hybrid.h

Lines changed: 18 additions & 22 deletions
@@ -18,31 +18,27 @@
 
 class llama_memory_hybrid : public llama_memory_i {
 public:
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     llama_memory_hybrid(
         const llama_model & model,
         /* attn */
-        ggml_type type_k,
-        ggml_type type_v,
-        bool v_trans,
-        uint32_t kv_size,
-        uint32_t n_pad,
-        uint32_t n_swa,
-        llama_swa_type swa_type,
-        /* recurrent */
-        ggml_type type_r,
-        ggml_type type_s,
-        uint32_t rs_size,
-        /* common */
-        uint32_t n_seq_max,
-        bool offload,
-        bool unified,
-        /* layer filters */
-        layer_filter_cb && filter_attn = nullptr,
-        layer_filter_cb && filter_recr = nullptr);
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        bool unified,
+        /* layer filters */
+        const layer_filter_cb & filter_attn = nullptr,
+        const layer_filter_cb & filter_recr = nullptr);
 
     ~llama_memory_hybrid() = default;
