
Commit 1a7e23d

feat: Allow custom layer filters for hybrid recurrent
This should help support architectures like Falcon H1, where there is overlap between the layers that need an attention cache and the layers that need a recurrent cache. #13979 (comment)

Branch: HybridRecurrentCache
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 9ba8615 commit 1a7e23d

File tree

2 files changed: +46 additions, -32 deletions


src/llama-kv-cache-hybrid-recurrent.cpp

Lines changed: 24 additions & 17 deletions
@@ -10,25 +10,30 @@
 
 llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
         const llama_model & model,
-                         /* attn */
-               ggml_type attn_type_k,
-               ggml_type attn_type_v,
-                    bool attn_v_trans,
-                uint32_t attn_kv_size,
-                uint32_t attn_n_pad,
-                uint32_t attn_n_swa,
-          llama_swa_type attn_swa_type,
-                         /* recurrent */
-               ggml_type recurrent_type_k,
-               ggml_type recurrent_type_v,
-                uint32_t recurrent_kv_size,
-                         /* common */
-                uint32_t n_seq_max,
-                    bool offload) :
+                          /* attn */
+                ggml_type attn_type_k,
+                ggml_type attn_type_v,
+                     bool attn_v_trans,
+                 uint32_t attn_kv_size,
+                 uint32_t attn_n_pad,
+                 uint32_t attn_n_swa,
+           llama_swa_type attn_swa_type,
+                          /* recurrent */
+                ggml_type recurrent_type_k,
+                ggml_type recurrent_type_v,
+                 uint32_t recurrent_kv_size,
+                          /* common */
+                 uint32_t n_seq_max,
+                     bool offload,
+                          /* layer filters */
+       layer_filter_cb && attn_filter,
+       layer_filter_cb && recurrent_filter) :
     hparams(model.hparams),
     kv_attn(new llama_kv_cache_unified(
         model,
-        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_filter == nullptr ?
+            [&](int32_t il) { return !model.hparams.recurrent_layer(il); }
+            : attn_filter,
         attn_type_k,
         attn_type_v,
         attn_v_trans,
@@ -41,7 +46,9 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     )),
     kv_recurrent(new llama_kv_cache_recurrent(
         model,
-        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_filter == nullptr ?
+            [&](int32_t il) { return model.hparams.recurrent_layer(il); }
+            : recurrent_filter,
         recurrent_type_k,
         recurrent_type_v,
         offload,

src/llama-kv-cache-hybrid-recurrent.h

Lines changed: 22 additions & 15 deletions
@@ -19,23 +19,30 @@
 
 class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_hybrid_recurrent(
         const llama_model & model,
-                         /* attn */
-               ggml_type attn_type_k,
-               ggml_type attn_type_v,
-                    bool attn_v_trans,
-                uint32_t attn_kv_size,
-                uint32_t attn_n_pad,
-                uint32_t attn_n_swa,
-          llama_swa_type attn_swa_type,
-                         /* recurrent */
-               ggml_type recurrent_type_k,
-               ggml_type recurrent_type_v,
-                uint32_t recurrent_kv_size,
-                         /* common */
-                uint32_t n_seq_max,
-                    bool offload);
+                          /* attn */
+                ggml_type attn_type_k,
+                ggml_type attn_type_v,
+                     bool attn_v_trans,
+                 uint32_t attn_kv_size,
+                 uint32_t attn_n_pad,
+                 uint32_t attn_n_swa,
+           llama_swa_type attn_swa_type,
+                          /* recurrent */
+                ggml_type recurrent_type_k,
+                ggml_type recurrent_type_v,
+                 uint32_t recurrent_kv_size,
+                          /* common */
+                 uint32_t n_seq_max,
+                     bool offload,
+                          /* layer filters */
+       layer_filter_cb && attn_filter = nullptr,
+       layer_filter_cb && recurrent_filter = nullptr);
 
     ~llama_kv_cache_hybrid_recurrent() = default;
 
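For context, here is a minimal usage sketch of the new optional filters, illustrating the Falcon H1-style case the commit message mentions: every layer gets an attention cache while the recurrent layers additionally get a recurrent cache, so the two layer sets overlap instead of being a disjoint split. This is not part of the commit; the helper function name, the cache sizes, the GGML type choices, and LLAMA_SWA_TYPE_NONE are placeholder assumptions, and only the two filter arguments are the point.

// Hypothetical sketch (not from this commit). Assumes the constructor declared
// in src/llama-kv-cache-hybrid-recurrent.h above; sizes and types are placeholders.
#include "llama-kv-cache-hybrid-recurrent.h"

#include <utility>

static llama_memory_i * make_falcon_h1_style_cache(const llama_model & model) {
    // Every layer gets an attention KV cache ...
    llama_kv_cache_hybrid_recurrent::layer_filter_cb attn_filter =
        [](int32_t /*il*/) { return true; };

    // ... and recurrent layers additionally get a recurrent cache. With the
    // nullptr defaults the two child caches split the layers disjointly on
    // hparams.recurrent_layer(il); custom filters allow the sets to overlap.
    llama_kv_cache_hybrid_recurrent::layer_filter_cb recurrent_filter =
        [&model](int32_t il) { return model.hparams.recurrent_layer(il); };

    return new llama_kv_cache_hybrid_recurrent(
        model,
        /* attn */      GGML_TYPE_F16, GGML_TYPE_F16, /* v_trans */ false,
                        /* kv_size */ 4096, /* n_pad */ 32, /* n_swa */ 0,
                        LLAMA_SWA_TYPE_NONE,          // assumed "no SWA" value
        /* recurrent */ GGML_TYPE_F32, GGML_TYPE_F32, /* recurrent_kv_size */ 1,
        /* common */    /* n_seq_max */ 1, /* offload */ true,
        /* filters */   std::move(attn_filter), std::move(recurrent_filter));
}

The std::move calls are needed because the constructor takes the filters as layer_filter_cb && rvalue references; passing nullptr (or omitting the arguments, given the defaults in the header) falls back to the original disjoint behavior.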
