From 22fd5bd078dbda3f60068a1289e596676ab436e3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 9 Oct 2025 19:33:30 +0300
Subject: [PATCH 1/5] graph : reuse hybrid graphs

---
 src/llama-graph.cpp         | 41 ++++++++++++++++++++++++++++++++++---
 src/llama-graph.h           | 10 +++++++--
 src/llama-memory-hybrid.cpp |  2 +-
 3 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index b199e9462..18b2413b5 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -458,8 +458,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    inp_attn->set_input(ubatch);
-    inp_rs->set_input(ubatch);
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+  //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    return res;
 }
 
 //
@@ -1914,7 +1949,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }
diff --git a/src/llama-graph.h b/src/llama-graph.h
index d0c3934f6..25e50238f 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -364,22 +364,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
+            const llama_cparams & cparams,
             std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
-            std::unique_ptr<llm_graph_input_rs> inp_rs,
-            const llama_memory_hybrid_context * mctx) :
+            std::unique_ptr<llm_graph_input_rs>      inp_rs,
+            const llama_memory_hybrid_context *      mctx) :
         inp_attn(std::move(inp_attn)),
         inp_rs(std::move(inp_rs)),
+        cparams(cparams),
         mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
     std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
     llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
     llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
 
+    const llama_cparams cparams;
+
     const llama_memory_hybrid_context * mctx;
 };
 
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index dfb8439e0..a1b45e4a3 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
     ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
-    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                 this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 

From cc23af915e2741fe300a3b75c999761e3e25806c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 9 Oct 2025 19:36:17 +0300
Subject: [PATCH 2/5] graph : reuse recurrent graphs

---
 src/llama-graph.cpp | 15 +++++++++++++++
 src/llama-graph.h   |  2 ++
 2 files changed, 17 insertions(+)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 18b2413b5..35768dabd 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -251,6 +251,21 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= s_copy->ne[0] == mctx->get_n_rs();
+
+    res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
+
+    return res;
+}
+
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 25e50238f..944d129c3 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -224,6 +224,8 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     ggml_tensor * s_copy; // I32 [n_rs]
 
     // views of s_copy, computed once per graph

From b1cf2eb640f87e4244c6a6e004433bef9a1d205a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 10 Oct 2025 10:44:41 +0300
Subject: [PATCH 3/5] graph : fix reuse check for recurrent inputs

---
 src/llama-graph.cpp | 11 ++++++++++-
 src/llama-graph.h   |  4 ++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 35768dabd..9ea0d5958 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -263,6 +263,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
+
     return res;
 }
 
@@ -509,6 +512,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
 
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
     return res;
 }
 
@@ -1893,6 +1899,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
 
@@ -1961,7 +1970,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl     (ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 944d129c3..caba9779b 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -234,6 +234,10 @@ class llm_graph_input_rs : public llm_graph_input_i {
     ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
+
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t  rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {

From b1865c914f290b96955f127f9db48edb80b238da Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 10 Oct 2025 10:57:35 +0300
Subject: [PATCH 4/5] memory : move the recurrent state into the memory context

---
 src/llama-graph.cpp            | 13 ++++++++-----
 src/llama-graph.h              |  8 ++++----
 src/llama-memory-recurrent.cpp | 17 ++++++++++-------
 src/llama-memory-recurrent.h   |  6 ++++--
 4 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 9ea0d5958..12ae019d9 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -235,6 +235,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
+    mctx(mctx),
+    head(mctx->get_head()),
+    rs_z(mctx->get_rs_z()) {
+}
+
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -263,8 +269,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= head == mctx->get_head();
-    res &= rs_z == mctx->get_rs_z();
+    res &= this->head == mctx->get_head();
+    res &= this->rs_z == mctx->get_rs_z();
 
     return res;
 }
@@ -1899,9 +1905,6 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
-    inp->head = mctx_cur->get_head();
-    inp->rs_z = mctx_cur->get_rs_z();
-
     return inp;
 }
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index caba9779b..44192c66a 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -219,7 +219,7 @@ class llm_graph_input_cls : public llm_graph_input_i {
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -235,9 +235,9 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     const llama_memory_recurrent_context * mctx;
 
-    // used in view offsets, need to match for valid graph reuse
-    uint32_t head;
-    int32_t  rs_z;
+    // need to match for valid graph reuse
+    const uint32_t head;
+    const int32_t  rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 276e1697d..71d426e6f 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -1092,12 +1092,15 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
-        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
+            n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
 }
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
+            n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
+}
 
 llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
 
@@ -1138,19 +1141,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
 }
 
 uint32_t llama_memory_recurrent_context::get_n_rs() const {
-    return is_full ? mem->size : mem->n;
+    return n_rs;
 }
 
 uint32_t llama_memory_recurrent_context::get_head() const {
-    return is_full ? 0 : mem->head;
+    return head;
 }
 
 int32_t llama_memory_recurrent_context::get_rs_z() const {
-    return is_full ? 0 : mem->rs_z;
+    return rs_z;
 }
 
 uint32_t llama_memory_recurrent_context::get_size() const {
-    return mem->size;
+    return size;
 }
 
 ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
@@ -1162,5 +1165,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
 }
 
 int32_t llama_memory_recurrent_context::s_copy(int i) const {
-    return mem->cells[i + mem->head].src0;
+    return mem->cells[i + head].src0;
 }
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
index 47f01d739..a2b19904f 100644
--- a/src/llama-memory-recurrent.h
+++ b/src/llama-memory-recurrent.h
@@ -175,8 +175,10 @@ class llama_memory_recurrent_context : public llama_memory_context_i {
     //
     // data needed for building the compute graph for the current ubatch:
-    // TODO: extract all the state like `head` and `n` here
     //
 
-    const bool is_full = false;
+    const uint32_t n_rs = 0;
+    const uint32_t head = 0;
+    const int32_t  rs_z = -1;
+    const uint32_t size = 0;
 };
 

From df46214a5f80c9b509f41062f4c17e7205ce3b1e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 10 Oct 2025 19:41:10 +0300
Subject: [PATCH 5/5] Revert "memory : move the recurrent state into the memory context"

This reverts commit 00f115fe810815d4a22a6dee0acc346131e970e1.
---
 src/llama-graph.cpp            | 13 +++++--------
 src/llama-graph.h              |  8 ++++----
 src/llama-memory-recurrent.cpp | 17 +++++++----------
 src/llama-memory-recurrent.h   |  6 ++----
 4 files changed, 18 insertions(+), 26 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 12ae019d9..9ea0d5958 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -235,12 +235,6 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
-    mctx(mctx),
-    head(mctx->get_head()),
-    rs_z(mctx->get_rs_z()) {
-}
-
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -269,8 +263,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= this->head == mctx->get_head();
-    res &= this->rs_z == mctx->get_rs_z();
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
 
     return res;
 }
@@ -1905,6 +1899,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 44192c66a..caba9779b 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -219,7 +219,7 @@ class llm_graph_input_cls : public llm_graph_input_i {
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -235,9 +235,9 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     const llama_memory_recurrent_context * mctx;
 
-    // need to match for valid graph reuse
-    const uint32_t head;
-    const int32_t  rs_z;
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t  rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 71d426e6f..276e1697d 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -1092,15 +1092,12 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
-        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
-            n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 }
 
 llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
-            n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
-}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
 
 llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
 
@@ -1141,19 +1138,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
 }
 
 uint32_t llama_memory_recurrent_context::get_n_rs() const {
-    return n_rs;
+    return is_full ? mem->size : mem->n;
 }
 
 uint32_t llama_memory_recurrent_context::get_head() const {
-    return head;
+    return is_full ? 0 : mem->head;
 }
 
 int32_t llama_memory_recurrent_context::get_rs_z() const {
-    return rs_z;
+    return is_full ? 0 : mem->rs_z;
 }
 
 uint32_t llama_memory_recurrent_context::get_size() const {
-    return size;
+    return mem->size;
 }
 
 ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
@@ -1165,5 +1162,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
 }
 
 int32_t llama_memory_recurrent_context::s_copy(int i) const {
-    return mem->cells[i + head].src0;
+    return mem->cells[i + mem->head].src0;
 }
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
index a2b19904f..47f01d739 100644
--- a/src/llama-memory-recurrent.h
+++ b/src/llama-memory-recurrent.h
@@ -175,10 +175,8 @@ class llama_memory_recurrent_context : public llama_memory_context_i {
     //
     // data needed for building the compute graph for the current ubatch:
+    // TODO: extract all the state like `head` and `n` here
     //
 
-    const uint32_t n_rs = 0;
-    const uint32_t head = 0;
-    const int32_t  rs_z = -1;
-    const uint32_t size = 0;
+    const bool is_full = false;
 };