@@ -1757,56 +1757,62 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
17571757
17581758ggml_tensor * llm_graph_context::build_rs (
17591759 ggml_tensor * s,
1760- ggml_tensor * state_copy,
1760+ ggml_tensor * state_copy_main,
1761+ ggml_tensor * state_copy_extra,
17611762 int32_t state_size,
17621763 int32_t n_seqs,
1763- uint32_t n_kv ,
1764- uint32_t kv_head ,
1765- uint32_t kv_size ,
1764+ uint32_t n_rs ,
1765+ uint32_t rs_head ,
1766+ uint32_t rs_size ,
17661767 int32_t rs_zero,
17671768 const llm_graph_get_rows_fn & get_state_rows) const {
17681769
1769- ggml_tensor * states = ggml_reshape_2d (ctx0, s, state_size, kv_size );
1770+ ggml_tensor * states = ggml_reshape_2d (ctx0, s, state_size, rs_size );
17701771
17711772 // Clear a single state which will then be copied to the other cleared states.
17721773 // Note that this is a no-op when the view is zero-sized.
17731774 ggml_tensor * state_zero = ggml_view_1d (ctx0, states, state_size*(rs_zero >= 0 ), rs_zero*states->nb [1 ]*(rs_zero >= 0 ));
17741775 ggml_build_forward_expand (gf, ggml_scale_inplace (ctx0, state_zero, 0 ));
17751776
17761777 // copy states
1777- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1778- // {state_size, kv_size } -> {state_size, n_seqs}
1779- ggml_tensor * output_states = get_state_rows (ctx0, states, ggml_view_1d (ctx0, state_copy, n_seqs, 0 ) );
1778+ // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
1779+ // {state_size, rs_size } -> {state_size, n_seqs}
1780+ ggml_tensor * output_states = get_state_rows (ctx0, states, state_copy_main );
17801781 ggml_build_forward_expand (gf, output_states);
17811782
1782- // copy extra states which won't be changed further (between n_seqs and n_kv )
1783- ggml_tensor * states_extra = ggml_get_rows (ctx0, states, ggml_view_1d (ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy-> nb [ 0 ]) );
1783+ // copy extra states which won't be changed further (between n_seqs and n_rs )
1784+ ggml_tensor * states_extra = ggml_get_rows (ctx0, states, state_copy_extra );
17841785 ggml_build_forward_expand (gf,
17851786 ggml_cpy (ctx0,
17861787 states_extra,
1787- ggml_view_1d (ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size (s))));
1788+ ggml_view_1d (ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size (s))));
17881789
17891790 return output_states;
17901791}
17911792
17921793static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl (
17931794 ggml_context * ctx0,
1795+ const llama_ubatch & ubatch,
17941796 const llama_memory_recurrent_context * mctx_cur) {
17951797
17961798 auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
17971799
1798- const auto n_rs = mctx_cur->get_n_rs ();
1800+ const int64_t n_rs = mctx_cur->get_n_rs ();
1801+ const int64_t n_seqs = ubatch.n_seqs ;
17991802
18001803 inp->s_copy = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_rs);
18011804 ggml_set_input (inp->s_copy );
18021805
1806+ inp->s_copy_main = ggml_view_1d (ctx0, inp->s_copy , n_seqs, 0 );
1807+ inp->s_copy_extra = ggml_view_1d (ctx0, inp->s_copy , n_rs - n_seqs, n_seqs * inp->s_copy ->nb [0 ]);
1808+
18031809 return inp;
18041810}
18051811
18061812llm_graph_input_rs * llm_graph_context::build_rs_inp () const {
18071813 const auto * mctx_cur = static_cast <const llama_memory_recurrent_context *>(mctx);
18081814
1809- auto inp = build_rs_inp_impl (ctx0, mctx_cur);
1815+ auto inp = build_rs_inp_impl (ctx0, ubatch, mctx_cur);
18101816
18111817 return (llm_graph_input_rs *) res->add_input (std::move (inp));
18121818}
@@ -1819,7 +1825,9 @@ ggml_tensor * llm_graph_context::build_rs(
18191825 const llm_graph_get_rows_fn & get_state_rows) const {
18201826 const auto * kv_state = inp->mctx ;
18211827
1822- return build_rs (s, inp->s_copy , state_size, n_seqs, kv_state->get_n_rs (), kv_state->get_head (), kv_state->get_size (), kv_state->get_rs_z (), get_state_rows);
1828+ return build_rs (s, inp->s_copy_main , inp->s_copy_extra , state_size, n_seqs,
1829+ kv_state->get_n_rs (), kv_state->get_head (), kv_state->get_size (), kv_state->get_rs_z (),
1830+ get_state_rows);
18231831}
18241832
18251833ggml_tensor * llm_graph_context::build_rwkv_token_shift_load (
@@ -1866,7 +1874,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
18661874llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid () const {
18671875 const auto * mctx_cur = static_cast <const llama_memory_hybrid_context *>(mctx);
18681876
1869- auto inp_rs = build_rs_inp_impl (ctx0, mctx_cur->get_recr ());
1877+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr ());
18701878 auto inp_attn = build_attn_inp_kv_unified_impl (ctx0, ubatch, hparams, cparams, mctx_cur->get_attn ());
18711879
18721880 auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move (inp_attn), std::move (inp_rs), mctx_cur);
0 commit comments