
Commit 50b8ad4

feat: Support hybrid recurrent cache in llm_graph_context

Branch: HybridRecurrentCache
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent b06b275 · commit 50b8ad4

File tree: 2 files changed (+37 −8 lines)


src/llama-graph.cpp

Lines changed: 33 additions & 8 deletions
@@ -7,6 +7,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -957,7 +958,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
@@ -974,7 +975,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
 
@@ -1028,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
 
@@ -1059,6 +1060,30 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
+const llama_kv_cache_unified_state * llm_graph_context::get_state_unified() const {
+    const auto * umstate = dynamic_cast<const llama_kv_cache_unified_state *>(mstate);
+    if (!umstate) {
+        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            umstate = hmstate->get_state_attn();
+        }
+    }
+    GGML_ASSERT(umstate);
+    return umstate;
+}
+
+const llama_kv_cache_recurrent_state * llm_graph_context::get_state_recurrent() const {
+    const auto * rmstate = dynamic_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!rmstate) {
+        const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            rmstate = hmstate->get_state_recurrent();
+        }
+    }
+    GGML_ASSERT(rmstate);
+    return rmstate;
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
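Both new accessors follow the same fallback pattern: try a direct dynamic_cast of mstate first, and only if that fails check whether mstate is the hybrid wrapper and unwrap the matching half. Below is a minimal, self-contained sketch of that pattern; the state classes are simplified stand-ins, not the real llama.cpp definitions.

#include <cassert>

// Simplified stand-ins for the memory-state interface and concrete states.
struct memory_state_i  { virtual ~memory_state_i() = default; };
struct unified_state   : memory_state_i {};
struct recurrent_state : memory_state_i {};

// Mirrors llama_kv_cache_hybrid_recurrent_state: one attention half and
// one recurrent half, each exposed through a getter.
struct hybrid_state : memory_state_i {
    const unified_state   * get_state_attn()      const { return &attn; }
    const recurrent_state * get_state_recurrent() const { return &recr; }

    unified_state   attn;
    recurrent_state recr;
};

// The fallback pattern from get_state_unified() above: direct cast first,
// then unwrap the hybrid wrapper; any other state type is a hard error.
const unified_state * get_state_unified(const memory_state_i * mstate) {
    const auto * ustate = dynamic_cast<const unified_state *>(mstate);
    if (!ustate) {
        if (const auto * hstate = dynamic_cast<const hybrid_state *>(mstate)) {
            ustate = hstate->get_state_attn();
        }
    }
    assert(ustate);
    return ustate;
}

int main() {
    unified_state u;
    hybrid_state  h;

    // The same call resolves both a plain unified state and a hybrid one,
    // so graph-builder code never needs to know which cache is in use.
    assert(get_state_unified(&u) == &u);
    assert(get_state_unified(&h) == h.get_state_attn());
    return 0;
}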
@@ -1234,7 +1259,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
 
@@ -1271,7 +1296,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     // store to KV cache
     {
@@ -1449,7 +1474,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
@@ -1481,7 +1506,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1502,7 +1527,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd            = hparams.n_embd;
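The switch from static_cast to the checked lookup is what makes the hybrid cache safe to pass through these builders: a static_cast would silently reinterpret a hybrid state as whatever type the builder expected, while dynamic_cast reports the mismatch so the accessor can unwrap the right half instead. A small illustration, again with stand-in types rather than the real llama.cpp ones:

#include <cstdio>

struct memory_state_i  { virtual ~memory_state_i() = default; };
struct recurrent_state : memory_state_i {};
struct hybrid_state    : memory_state_i { /* wraps attn + recurrent halves */ };

int main() {
    hybrid_state h;
    const memory_state_i * mstate = &h;

    // Pre-commit behavior: static_cast would blindly treat the hybrid state
    // as a recurrent one and read through a wrong-type pointer.
    // Post-commit behavior: dynamic_cast yields nullptr on a type mismatch,
    // which is the signal to fall back to unwrapping the hybrid state.
    const auto * r = dynamic_cast<const recurrent_state *>(mstate);
    std::printf("direct cast: %s\n", r ? "hit" : "miss, unwrap the hybrid state");
    return 0;
}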

src/llama-graph.h

Lines changed: 4 additions & 0 deletions
@@ -531,6 +531,8 @@ struct llm_graph_context {
     // attention
     //
 
+    const llama_kv_cache_unified_state * get_state_unified() const;
+
     ggml_tensor * build_attn_mha(
             ggml_cgraph * gf,
             ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
@@ -605,6 +607,8 @@ struct llm_graph_context {
     // recurrent
     //
 
+    const llama_kv_cache_recurrent_state * get_state_recurrent() const;
+
     ggml_tensor * build_copy_mask_state(
             ggml_cgraph * gf,
             ggml_tensor * s,
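In the header, each accessor is declared in the section whose builders use it: get_state_unified() alongside the attention helpers, get_state_recurrent() alongside the recurrent ones. From the calls in this diff, the hybrid state's interface can be inferred to expose one getter per half. The sketch below shows that assumed shape only; the real declaration lives in llama-kv-cache-hybrid-recurrent.h, which this commit does not touch, and the base interface used here is a stand-in.

// Minimal stand-in base so the sketch is self-contained; the real
// llama_memory_state_i interface is richer than this.
struct llama_memory_state_i { virtual ~llama_memory_state_i() = default; };

class llama_kv_cache_unified_state;   // defined in llama-kv-cache-unified.h
class llama_kv_cache_recurrent_state; // defined in llama-kv-cache-recurrent.h

// Assumed shape, inferred from hmstate->get_state_attn() and
// hmstate->get_state_recurrent() in the diff above; not the real header.
class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
public:
    const llama_kv_cache_unified_state   * get_state_attn()      const;
    const llama_kv_cache_recurrent_state * get_state_recurrent() const;
};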
