@@ -242,15 +242,15 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv = kv_state->get_n_kv();
+    const int64_t n_rs = mem_state->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_copy(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->s_copy(i);
         }
     }
 }
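The loop above materializes the recurrent-state copy map into the int32 s_copy graph input: cell i records which cell its state should be copied from. A minimal sketch of that contract with stand-in types (toy_recurrent_state and fill_s_copy are hypothetical names; only the get_n_rs()/s_copy(i) shape mirrors the hunk above):

#include <cstdint>
#include <vector>

// Hypothetical stand-in for llama_memory_recurrent_state: each recurrent
// cell i reports the cell its state should be copied from (itself when no
// copy is needed).
struct toy_recurrent_state {
    std::vector<int32_t> src;
    int64_t get_n_rs() const { return (int64_t) src.size(); }
    int32_t s_copy(uint32_t i) const { return src[i]; }
};

// Mirrors the set_input() loop above: flatten the mapping into the buffer
// backing the s_copy graph input.
static void fill_s_copy(const toy_recurrent_state & st, int32_t * data) {
    const int64_t n_rs = st.get_n_rs();
    for (uint32_t i = 0; i < n_rs; ++i) {
        data[i] = st.s_copy(i);
    }
}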
@@ -406,18 +406,18 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
-    const int64_t n_kv = kv_state->get_state_recurrent()->get_n_kv();
+    const int64_t n_rs = mem_state->get_state_recurrent()->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->get_state_recurrent()->s_copy(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->get_state_recurrent()->s_copy(i);
         }
     }
 }
@@ -1050,14 +1050,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->kv_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         // cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1067,9 +1067,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     }
 
     {
-        const auto n_kv = kv_state->get_state_recurrent()->get_n_kv();
+        const auto n_rs = mem_state->get_state_recurrent()->get_n_rs();
 
-        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
         ggml_set_input(inp->s_copy);
     }
 
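The two blocks above size the hybrid input's buffers independently: the KQ mask follows the attention sub-state's n_kv, while the copy map follows the recurrent sub-state's n_rs. A toy sketch of the resulting shapes, with made-up sizes and pad_to assumed to match GGML_PAD's round-up semantics:

#include <cstdint>
#include <cstdio>

// Assumed to match GGML_PAD: round n up to a multiple of align.
static int64_t pad_to(int64_t n, int64_t align) {
    return (n + align - 1) / align * align;
}

int main() {
    const int64_t n_kv     = 512; // attention sub-state: usable KV cells
    const int64_t n_rs     = 4;   // recurrent sub-state: one slot per sequence
    const int64_t n_tokens = 33;
    const int64_t kq_pad   = 32;  // hypothetical GGML_KQ_MASK_PAD value

    // self_kq_mask: one f32 per (KV cell, padded token) pair.
    std::printf("self_kq_mask: [%lld, %lld] f32\n",
            (long long) n_kv, (long long) pad_to(n_tokens, kq_pad));
    // s_copy: one i32 source index per recurrent slot.
    std::printf("s_copy:       [%lld] i32\n", (long long) n_rs);
}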
@@ -1557,9 +1557,9 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
 
     auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
 
-    const auto n_kv = kv_state->get_n_kv();
+    const auto n_rs = kv_state->get_n_rs();
 
-    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
@@ -1574,7 +1574,7 @@ ggml_tensor * llm_graph_context::build_rs(
         bool avoid_copies) const {
     const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
15791579
15801580ggml_tensor * llm_graph_context::build_rs (
@@ -1586,7 +1586,7 @@ ggml_tensor * llm_graph_context::build_rs(
15861586 bool avoid_copies) const {
15871587 const auto * kv_state = static_cast <const llama_memory_hybrid_state *>(mstate)->get_state_recurrent ();
15881588
1589- return build_rs (gf, s, inp->s_copy , state_size, n_seqs, kv_state->get_n_kv (), kv_state->get_head (), kv_state->get_size (), kv_state->get_rs_z (), avoid_copies);
1589+ return build_rs (gf, s, inp->s_copy , state_size, n_seqs, kv_state->get_n_rs (), kv_state->get_head (), kv_state->get_size (), kv_state->get_rs_z (), avoid_copies);
15901590}
15911591
15921592ggml_tensor * llm_graph_context::build_rwkv_token_shift_load (
@@ -1600,11 +1600,11 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
16001600
16011601 const int64_t n_seqs = ubatch.n_seqs ;
16021602
1603- ggml_tensor * token_shift_all = kv_state->get_k_l (il);
1603+ ggml_tensor * token_shift_all = kv_state->get_r_l (il);
16041604
16051605 ggml_tensor * token_shift = build_rs (
16061606 inp, gf, token_shift_all,
1607- hparams.n_embd_k_s (), n_seqs);
1607+ hparams.n_embd_r (), n_seqs);
16081608
16091609 token_shift = ggml_reshape_3d (ctx0, token_shift, hparams.n_embd , token_shift_count, n_seqs);
16101610
@@ -1627,7 +1627,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     return ggml_cpy(
             ctx0,
             ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-            ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
+            ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
     );
 }
 
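The store path above writes the token-shift states back into the per-layer r tensor through a 1-D view of hparams.n_embd_r() elements per sequence, starting at cell kv_head. A hedged condensation of that offset arithmetic (rs_view and token_shift_store_view are hypothetical helpers, not llama.cpp API):

#include <cstddef>
#include <cstdint>

// Hypothetical helper mirroring the ggml_view_1d() arguments in
// build_rwkv_token_shift_store() above.
struct rs_view {
    int64_t n_elems;  // view length in elements
    size_t  byte_off; // byte offset of the first element
};

static rs_view token_shift_store_view(
        int64_t  n_embd_r,  // per-cell token-shift width, hparams.n_embd_r()
        int64_t  n_seqs,    // sequences in the current ubatch
        uint32_t kv_head,   // first destination cell
        size_t   elem_size) // ggml_element_size() of the r tensor
{
    return {
        n_embd_r * n_seqs,                        // hparams.n_embd_r()*n_seqs
        (size_t) (n_embd_r * kv_head) * elem_size // n_embd_r()*kv_head*elem_size
    };
}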