Skip to content

Commit a126bc4 — graph : fix reuse check for recurrent inputs
Parent commit: 12a9751

File tree

2 files changed: +14 additions, -1 deletion

src/llama-graph.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
263263
res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
264264
res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
265265

266+
res &= head == mctx->get_head();
267+
res &= rs_z == mctx->get_rs_z();
268+
266269
return res;
267270
}
268271

@@ -509,6 +512,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
509512
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
510513
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
511514

515+
res &= inp_rs->head == mctx->get_recr()->get_head();
516+
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
517+
512518
return res;
513519
}
514520

@@ -1893,6 +1899,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
18931899
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
18941900
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
18951901

1902+
inp->head = mctx_cur->get_head();
1903+
inp->rs_z = mctx_cur->get_rs_z();
1904+
18961905
return inp;
18971906
}
18981907

@@ -1961,7 +1970,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
19611970
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
19621971
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
19631972

1964-
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
1973+
auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
19651974
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
19661975

19671976
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);

src/llama-graph.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ class llm_graph_input_rs : public llm_graph_input_i {
234234
ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
235235

236236
const llama_memory_recurrent_context * mctx;
237+
238+
// used in view offsets, need to match for valid graph reuse
239+
uint32_t head;
240+
int32_t rs_z;
237241
};
238242

239243
class llm_graph_input_cross_embd : public llm_graph_input_i {

0 commit comments

Comments (0)