
Commit 7a0fe25

feat: Add layer filter to recurrent cache
Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>

1 parent ec22867

File tree

4 files changed: +21 lines, -12 lines

src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-model.cpp
tests/test-memory.cpp

src/llama-kv-cache.cpp

Lines changed: 12 additions & 6 deletions
```diff
@@ -1828,12 +1828,13 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
 //
 
 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model &  model,
+          layer_filter_cb && filter,
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    offload,
+                 uint32_t    kv_size,
+                 uint32_t    n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
     LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -1875,6 +1876,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     v_l.reserve(n_layer);
 
     for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
```
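With this guard, an empty callback keeps every layer, so existing callers are unaffected. A minimal sketch of a caller that uses the filter, assuming `layer_filter_cb` is a `std::function<bool(int32_t il)>` that returns true for layers the cache should cover (the typedef is not visible in this diff, and the predicate below is purely illustrative):

```cpp
// Hypothetical caller: allocate recurrent state only for even-numbered layers.
// A lambda converts implicitly to the assumed std::function-based
// layer_filter_cb, and the temporary binds to the `layer_filter_cb &&` parameter.
llama_kv_cache_recurrent cache(
        model,
        [](int32_t il) { return il % 2 == 0; },
        GGML_TYPE_F32,
        GGML_TYPE_F32,
        /* offload   */ false,
        /* kv_size   */ 1024,
        /* n_seq_max */ 1);
```

For odd layers the constructor would then log `layer %3d: skipped` at debug level and allocate no k/v tensors.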
src/llama-kv-cache.h

Lines changed: 7 additions & 6 deletions
```diff
@@ -425,12 +425,13 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     };
 
     llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max);
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                    ggml_type    type_k,
+                    ggml_type    type_v,
+                         bool    offload,
+                     uint32_t    kv_size,
+                     uint32_t    n_seq_max);
 
     ~llama_kv_cache_recurrent() = default;
 
```
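The filter is taken by rvalue reference and tested for truthiness before being called in the constructor, which matches a `std::function` whose empty state converts to false. A plausible shape for the typedef, shown here as an assumption since its definition is outside this diff:

```cpp
#include <cstdint>
#include <functional>

// Assumed definition (not part of this commit's visible hunks):
// return true to allocate cache cells for layer il, false to skip it.
using layer_filter_cb = std::function<bool(int32_t il)>;
```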
src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -13204,6 +13204,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         {
             res = new llama_kv_cache_recurrent(
                     *this,
+                    nullptr,
                     GGML_TYPE_F32,
                     GGML_TYPE_F32,
                     cparams.offload_kqv,
```
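Passing `nullptr` preserves the existing behavior for purely recurrent architectures: every layer gets cache cells. On the HybridCache branch this is presumably the hook for hybrid models, where the recurrent cache should cover only the recurrent layers. A hypothetical sketch of such wiring (the `is_recurrent` helper and the trailing arguments are assumptions, not part of this commit):

```cpp
// Hypothetical hybrid wiring: attention layers would live in a unified cache,
// recurrent layers in this one; hparams.is_recurrent(il) is an assumed helper.
res = new llama_kv_cache_recurrent(
        *this,
        [&](int32_t il) { return hparams.is_recurrent(il); },
        GGML_TYPE_F32,
        GGML_TYPE_F32,
        cparams.offload_kqv,
        /* kv_size   */ cparams.n_ctx,
        /* n_seq_max */ cparams.n_seq_max);
```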

tests/test-memory.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -156,6 +156,7 @@ static void test_llama_kv_cache_recurrent_constructor() {
     auto model = _make_model(LLM_ARCH_MAMBA);
     llama_kv_cache_recurrent cache(
         /* model */ *model,
+        /* filter */ nullptr,
         /* type_k */ GGML_TYPE_F32,
         /* type_v */ GGML_TYPE_F16,
         /* offload */ false,
```
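This test only covers the `nullptr` path. A natural companion case, sketched here hypothetically with illustrative values, would pass a real filter and let the constructor skip the rejected layers:

```cpp
// Hypothetical: a cache that keeps only layer 0; all other layers should be
// skipped during tensor allocation by the new filter check.
llama_kv_cache_recurrent filtered_cache(
    /* model     */ *model,
    /* filter    */ [](int32_t il) { return il == 0; },
    /* type_k    */ GGML_TYPE_F32,
    /* type_v    */ GGML_TYPE_F16,
    /* offload   */ false,
    /* kv_size   */ 10,
    /* n_seq_max */ 1);
```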
