
Commit 52d7627

feat: Add layer filter to recurrent cache

Branch: HybridCache
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent 60aab95
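This commit threads an optional layer filter through the recurrent KV cache: the constructor gains a layer_filter_cb && filter parameter, and any layer for which the filter returns false is skipped when the cache buffers are built. Both existing call sites pass nullptr, so behavior is unchanged for now. The alias itself is not part of this diff; the sketch below assumes it is the std::function shape used elsewhere in llama-kv-cache.h:

    #include <cstdint>
    #include <functional>

    // Assumed callback shape: return true to keep layer il in the cache,
    // false to have the recurrent cache skip it entirely.
    using layer_filter_cb = std::function<bool(int32_t il)>;

    // Illustrative filter: keep only the first four layers.
    layer_filter_cb filter = [](int32_t il) { return il < 4; };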

File tree

4 files changed: +21 -12 lines

  src/llama-kv-cache.cpp
  src/llama-kv-cache.h
  src/llama-model.cpp
  tests/test-memory.cpp


src/llama-kv-cache.cpp

Lines changed: 12 additions & 6 deletions
@@ -1740,12 +1740,13 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
 //
 
 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model & model,
+          layer_filter_cb && filter,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   offload,
+                 uint32_t   kv_size,
+                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
     LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -1787,6 +1788,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     v_l.reserve(n_layer);
 
     for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
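Taken on its own, the new skip logic is easy to demo. Here is a minimal standalone sketch of the loop above, with a toy filter standing in for a real one (LLAMA_TYPE tensors and LLAMA_LOG_DEBUG replaced by a vector and printf):

    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <vector>

    using layer_filter_cb = std::function<bool(int32_t il)>;

    int main() {
        const int32_t n_layer = 8;
        layer_filter_cb filter = [](int32_t il) { return il % 2 == 0; }; // toy: even layers only

        std::vector<int32_t> kept;
        for (int32_t i = 0; i < n_layer; i++) {
            if (filter && !filter(i)) {           // same guard as the constructor
                std::printf("layer %3d: skipped\n", i);
                continue;
            }
            kept.push_back(i);                    // stands in for creating the layer's k/v tensors
        }
        std::printf("cached %zu of %d layers\n", kept.size(), n_layer);
    }

Note that with a non-null filter, k_l and v_l end up holding fewer than n_layer entries, which any later indexing by layer id presumably has to account for.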

src/llama-kv-cache.h

Lines changed: 7 additions & 6 deletions
@@ -412,12 +412,13 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     };
 
     llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max);
+            const llama_model & model,
+              layer_filter_cb && filter,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   offload,
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max);
 
     ~llama_kv_cache_recurrent() = default;
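The declaration takes the filter by rvalue reference, so call sites either hand over a temporary (including one constructed from nullptr) or move in a named callable. A small self-contained illustration of that binding, with take_filter as a stand-in for the constructor:

    #include <cstdint>
    #include <functional>
    #include <utility>

    using layer_filter_cb = std::function<bool(int32_t il)>;

    // Stand-in for the constructor parameter: a sink that consumes the callable.
    static void take_filter(layer_filter_cb && filter) {
        layer_filter_cb local = std::move(filter);
        (void) local;
    }

    int main() {
        take_filter(nullptr);                            // "no filter": cache every layer
        take_filter([](int32_t il) { return il != 0; }); // temporary binds to &&

        layer_filter_cb named = [](int32_t il) { return il < 4; };
        take_filter(std::move(named));                   // named callables must be moved
    }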

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
@@ -13208,6 +13208,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = new llama_kv_cache_recurrent(
                         *this,
+                        nullptr,
                         GGML_TYPE_F32,
                         GGML_TYPE_F32,
                         cparams.offload_kqv,
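Passing nullptr here keeps the previous behavior of caching every layer. Given the HybridCache branch name, a plausible follow-up is a model whose recurrent cache covers only its recurrent layers while an attention cache takes the rest; the predicate below is purely hypothetical and not part of this commit:

    #include <cstdint>
    #include <functional>

    using layer_filter_cb = std::function<bool(int32_t il)>;

    // Hypothetical: a hybrid model splits layers across two caches with
    // complementary filters. is_recurrent() is an assumed predicate.
    struct toy_hparams {
        bool is_recurrent(int32_t il) const { return il % 2 == 1; } // toy rule
    };

    int main() {
        toy_hparams hparams;
        layer_filter_cb recurrent_filter = [&](int32_t il) { return  hparams.is_recurrent(il); };
        layer_filter_cb attention_filter = [&](int32_t il) { return !hparams.is_recurrent(il); };
        (void) recurrent_filter;
        (void) attention_filter;
    }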

tests/test-memory.cpp

Lines changed: 1 addition & 0 deletions
@@ -209,6 +209,7 @@ static void test_llama_kv_cache_recurrent_constructor() {
     auto model = _make_model(LLM_ARCH_MAMBA);
     llama_kv_cache_recurrent cache(
         /* model     */ *model,
+        /* filter    */ nullptr,
         /* type_k    */ GGML_TYPE_F32,
         /* type_v    */ GGML_TYPE_F16,
         /* offload   */ false,
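A natural companion test (hypothetical, not part of this commit) would exercise the same constructor with a non-null filter; the trailing argument values below are illustrative, since the diff cuts off before kv_size and n_seq_max:

    static void test_llama_kv_cache_recurrent_layer_filter() {
        auto model = _make_model(LLM_ARCH_MAMBA);
        llama_kv_cache_recurrent cache(
            /* model     */ *model,
            /* filter    */ [](int32_t il) { return il == 0; }, // keep layer 0 only
            /* type_k    */ GGML_TYPE_F32,
            /* type_v    */ GGML_TYPE_F16,
            /* offload   */ false,
            /* kv_size   */ 10,
            /* n_seq_max */ 1);
    }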
