@@ -263,6 +263,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
263263 res &= s_copy_main->ne [0 ] == params.ubatch .n_seqs ;
264264 res &= s_copy_extra->ne [0 ] == mctx->get_n_rs () - params.ubatch .n_seqs ;
265265
266+ res &= head == mctx->get_head ();
267+ res &= rs_z == mctx->get_rs_z ();
268+
266269 return res;
267270}
268271
@@ -509,6 +512,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
509512 res &= inp_rs->s_copy_main ->ne [0 ] == params.ubatch .n_seqs ;
510513 res &= inp_rs->s_copy_extra ->ne [0 ] == mctx->get_recr ()->get_n_rs () - params.ubatch .n_seqs ;
511514
515+ res &= inp_rs->head == mctx->get_recr ()->get_head ();
516+ res &= inp_rs->rs_z == mctx->get_recr ()->get_rs_z ();
517+
512518 return res;
513519}
514520
@@ -1894,6 +1900,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
18941900 inp->s_copy_main = ggml_view_1d (ctx0, inp->s_copy , n_seqs, 0 );
18951901 inp->s_copy_extra = ggml_view_1d (ctx0, inp->s_copy , n_rs - n_seqs, n_seqs * inp->s_copy ->nb [0 ]);
18961902
1903+ inp->head = mctx_cur->get_head ();
1904+ inp->rs_z = mctx_cur->get_rs_z ();
1905+
18971906 return inp;
18981907}
18991908
@@ -1962,7 +1971,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
19621971llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid () const {
19631972 const auto * mctx_cur = static_cast <const llama_memory_hybrid_context *>(mctx);
19641973
1965- auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr ());
1974+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr ());
19661975 auto inp_attn = build_attn_inp_kv_impl (ctx0, ubatch, hparams, cparams, mctx_cur->get_attn ());
19671976
19681977 auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move (inp_attn), std::move (inp_rs), mctx_cur);