Skip to content

Commit 12a9751

Browse files
committed
graph : reuse recurrent graphs
1 parent ca00002 commit 12a9751

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/llama-graph.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,21 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
251251
}
252252
}
253253

254+
bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
255+
const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
256+
257+
this->mctx = mctx;
258+
259+
bool res = true;
260+
261+
res &= s_copy->ne[0] == mctx->get_n_rs();
262+
263+
res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
264+
res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
265+
266+
return res;
267+
}
268+
254269
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
255270
GGML_UNUSED(ubatch);
256271

src/llama-graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ class llm_graph_input_rs : public llm_graph_input_i {
224224

225225
void set_input(const llama_ubatch * ubatch) override;
226226

227+
bool can_reuse(const llm_graph_params & params) override;
228+
227229
ggml_tensor * s_copy; // I32 [n_rs]
228230

229231
// views of s_copy, computed once per graph

0 commit comments

Comments
 (0)