 
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
-#include "llama-kv-cache-recurrent.h"
-#include "llama-kv-cache-hybrid-recurrent.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -1050,7 +1050,7 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate);
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, kv_state);
 
@@ -1447,7 +1447,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
 
     // store to KV cache
     {
@@ -1553,7 +1553,7 @@ ggml_tensor * llm_graph_context::build_rs(
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
 
@@ -1572,7 +1572,7 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
            bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
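
Note on the hunk above: the only change is the type named in the static_cast; the state object's interface is unchanged. Below is a minimal sketch of the same forwarding overload, assuming only the accessors already visible in this hunk (get_n_kv, get_head, get_size, get_rs_z); the comments describe assumed meanings and are not part of the diff.

    // Hypothetical sketch (not part of the diff): resolve the concrete recurrent
    // state from the generic memory-state pointer, then forward its metadata to
    // the core build_rs() overload, exactly as the hunk above does in one line.
    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);

    const int32_t n_kv = kv_state->get_n_kv(); // recurrent states visible to this ubatch (assumed meaning)
    const int32_t head = kv_state->get_head(); // first cell used by the current slot (assumed meaning)
    const int32_t size = kv_state->get_size(); // total number of cells in the buffer (assumed meaning)
    const int32_t rs_z = kv_state->get_rs_z(); // index of the zero-ed state (assumed meaning)

    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, n_kv, head, size, rs_z, avoid_copies);
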
@@ -1584,7 +1584,7 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
            bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recurrent();
 
     return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
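
The two hybrid call sites touched above (build_attn and this build_rs overload) differ only in which sub-state they pull out of the hybrid memory state. Below is a minimal sketch of that split, using only the accessors that appear in this diff (get_state_attn, get_state_recurrent); the variable names are illustrative, not taken from the source.

    // Hypothetical sketch (not part of the diff): the hybrid state bundles an
    // attention (unified KV cache) sub-state and a recurrent sub-state, and each
    // code path casts mstate once and picks the sub-state it needs.
    const auto * hyb = static_cast<const llama_memory_hybrid_state *>(mstate);

    const auto * attn_state = hyb->get_state_attn();      // used by build_attn when storing K/V
    const auto * rec_state  = hyb->get_state_recurrent(); // used by build_rs for recurrent layers
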
@@ -1594,7 +1594,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_cgraph * gf,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1615,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;