@@ -7,6 +7,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -957,7 +958,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
@@ -974,7 +975,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
 
@@ -1028,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
 
@@ -1059,6 +1060,30 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
     return pos_bias;
 }
 
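+// Resolve mstate to the unified KV cache state; when the memory state is a
+// hybrid recurrent state (attention + recurrent), use its attention sub-state.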
+const llama_kv_cache_unified_state * llm_graph_context::get_state_unified() const {
+    const auto * umstate = dynamic_cast<const llama_kv_cache_unified_state *>(mstate);
+    if (!umstate) {
+        const auto * hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            umstate = hmstate->get_state_attn();
+        }
+    }
+    GGML_ASSERT(umstate);
+    return umstate;
+}
+
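+// Same lookup for the recurrent KV cache state: either mstate itself is
+// recurrent, or it is a hybrid recurrent state and we use its recurrent sub-state.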
+const llama_kv_cache_recurrent_state * llm_graph_context::get_state_recurrent() const {
+    const auto * rmstate = dynamic_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!rmstate) {
+        const auto * hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+        if (hmstate) {
+            rmstate = hmstate->get_state_recurrent();
+        }
+    }
+    GGML_ASSERT(rmstate);
+    return rmstate;
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
@@ -1234,7 +1259,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
 
@@ -1271,7 +1296,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * kv_state = get_state_unified();
 
     // store to KV cache
     {
@@ -1449,7 +1474,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
@@ -1481,7 +1506,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1502,7 +1527,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = get_state_recurrent();
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;