
Commit 7c5deb0

cont : add option to filter layers from the KV cache
ggml-ci
1 parent e5bfd55 commit 7c5deb0

3 files changed: +28 -15 lines changed

src/llama-kv-cache.cpp

Lines changed: 14 additions & 7 deletions

@@ -23,13 +23,14 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
 }

 llama_kv_cache_unified::llama_kv_cache_unified(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   v_trans,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
+        const llama_model &  model,
+          layer_filter_cb && filter,
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    v_trans,
+                     bool    offload,
+                 uint32_t    kv_size,
+                 uint32_t    padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
     has_shift = false;
     can_shift = true;

@@ -73,6 +74,11 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     cells.resize(kv_size);

     for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        if (filter && !filter(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -1482,6 +1488,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
     uint32_t v_trans;
     uint32_t n_layer;
+
     io.read_to(&v_trans, sizeof(v_trans));
     io.read_to(&n_layer, sizeof(n_layer));
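For context, the new constructor argument is an optional predicate: when it is set, the allocation loop above asks it once per layer and simply skips the K/V tensors of any layer it rejects. Below is a minimal standalone sketch of that filter-and-skip pattern, not the llama.cpp code itself; the helper name make_layers and the toy predicate are illustrative assumptions.

    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // mirrors layer_filter_cb from llama-kv-cache.h: return true to keep layer il
    using layer_filter_cb = std::function<bool(int32_t il)>;

    // collect the layer indices that would actually get K/V tensors;
    // an empty filter means "keep every layer", i.e. the previous behavior
    static std::vector<int32_t> make_layers(int32_t n_layer, const layer_filter_cb & filter) {
        std::vector<int32_t> layers;
        for (int32_t il = 0; il < n_layer; il++) {
            if (filter && !filter(il)) {
                printf("layer %3d: skipped\n", il);
                continue;
            }
            layers.push_back(il);
        }
        return layers;
    }

    int main() {
        // toy predicate: keep only the first 4 of 8 layers
        const auto layers = make_layers(8, [](int32_t il) { return il < 4; });
        printf("kept %zu of 8 layers\n", layers.size());
        return 0;
    }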

src/llama-kv-cache.h

Lines changed: 13 additions & 8 deletions

@@ -92,14 +92,17 @@ class llama_kv_cache_unified : public llama_kv_cache {
 public:
     static uint32_t get_padding(const llama_cparams & cparams);

+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_unified(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   v_trans,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   padding);
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                    ggml_type    type_k,
+                    ggml_type    type_v,
+                         bool    v_trans,
+                         bool    offload,
+                     uint32_t    kv_size,
+                     uint32_t    padding);

     ~llama_kv_cache_unified() = default;

@@ -200,7 +203,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
     };

     struct kv_layer {
-        uint32_t il; // layer index in the original model
+        // layer index in the model
+        // note: can be different from the layer index in the KV cache
+        uint32_t il;

         ggml_tensor * k;
         ggml_tensor * v;
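The updated comment on kv_layer::il matters once a filter is in play: a skipped layer takes no slot in the cache, so a layer's position inside the cache can drift away from its index in the model. A small sketch of that mapping, assuming a toy filter that keeps only the last layers (the std::map is purely illustrative, not how llama.cpp stores the association):

    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <map>

    int main() {
        const int32_t n_layer = 6;
        // assumed toy filter: keep only layers 3..5
        const std::function<bool(int32_t)> filter = [](int32_t il) { return il >= 3; };

        std::map<int32_t, size_t> model_to_cache; // model layer index -> slot in the cache
        size_t slot = 0;
        for (int32_t il = 0; il < n_layer; il++) {
            if (filter && !filter(il)) {
                continue; // filtered out: no KV tensors, no cache slot
            }
            model_to_cache[il] = slot++;
        }

        // prints: model layer 3 -> cache slot 0, layer 4 -> slot 1, layer 5 -> slot 2
        for (const auto & [il, pos] : model_to_cache) {
            printf("model layer %d -> cache slot %zu\n", il, pos);
        }
        return 0;
    }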

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions

@@ -12895,6 +12895,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,

                 res = new llama_kv_cache_unified(
                         *this,
+                        nullptr,
                         params.type_k,
                         params.type_v,
                         !cparams.flash_attn,
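The only call site touched by this commit passes nullptr for the new filter argument, which keeps the existing behavior: a std::function built from nullptr is empty, so the filter check in the constructor never skips a layer. A tiny standalone sketch of that semantics (not library code):

    #include <cstdint>
    #include <cstdio>
    #include <functional>

    int main() {
        std::function<bool(int32_t)> filter = nullptr; // what create_memory passes above
        // an empty std::function is falsy, so "filter && !filter(il)" never fires
        // and every layer gets its K/V tensors, exactly as before this change
        printf("filter set: %s\n", filter ? "yes" : "no"); // prints: filter set: no
        return 0;
    }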
