@@ -263,6 +263,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
263263 res &= s_copy_main->ne [0 ] == params.ubatch .n_seqs ;
264264 res &= s_copy_extra->ne [0 ] == mctx->get_n_rs () - params.ubatch .n_seqs ;
265265
266+ res &= head == mctx->get_head ();
267+ res &= rs_z == mctx->get_rs_z ();
268+
266269 return res;
267270}
268271
@@ -509,6 +512,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
509512 res &= inp_rs->s_copy_main ->ne [0 ] == params.ubatch .n_seqs ;
510513 res &= inp_rs->s_copy_extra ->ne [0 ] == mctx->get_recr ()->get_n_rs () - params.ubatch .n_seqs ;
511514
515+ res &= inp_rs->head == mctx->get_recr ()->get_head ();
516+ res &= inp_rs->rs_z == mctx->get_recr ()->get_rs_z ();
517+
512518 return res;
513519}
514520
@@ -1893,6 +1899,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
18931899 inp->s_copy_main = ggml_view_1d (ctx0, inp->s_copy , n_seqs, 0 );
18941900 inp->s_copy_extra = ggml_view_1d (ctx0, inp->s_copy , n_rs - n_seqs, n_seqs * inp->s_copy ->nb [0 ]);
18951901
1902+ inp->head = mctx_cur->get_head ();
1903+ inp->rs_z = mctx_cur->get_rs_z ();
1904+
18961905 return inp;
18971906}
18981907
@@ -1961,7 +1970,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
19611970llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid () const {
19621971 const auto * mctx_cur = static_cast <const llama_memory_hybrid_context *>(mctx);
19631972
1964- auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr ());
1973+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr ());
19651974 auto inp_attn = build_attn_inp_kv_impl (ctx0, ubatch, hparams, cparams, mctx_cur->get_attn ());
19661975
19671976 auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move (inp_attn), std::move (inp_rs), mctx_cur);
0 commit comments