Skip to content

Commit ca00002

Browse files
committed
graph : reuse hybrid graphs
1 parent dcca0d3 commit ca00002

File tree

3 files changed

+47
-6
lines changed

3 files changed

+47
-6
lines changed

src/llama-graph.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
458458
}
459459

460460
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
461-
inp_attn->set_input(ubatch);
462-
inp_rs->set_input(ubatch);
461+
mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
462+
mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
463+
464+
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
465+
466+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
467+
468+
if (inp_rs->s_copy) {
469+
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
470+
int32_t * data = (int32_t *) inp_rs->s_copy->data;
471+
472+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
473+
for (uint32_t i = 0; i < n_rs; ++i) {
474+
data[i] = mctx->get_recr()->s_copy(i);
475+
}
476+
}
477+
}
478+
479+
bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
480+
const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
481+
482+
this->mctx = mctx;
483+
484+
bool res = true;
485+
486+
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
487+
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
488+
489+
res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
490+
res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
491+
492+
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
493+
494+
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
495+
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
496+
497+
return res;
463498
}
464499

465500
//
@@ -1914,7 +1949,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
19141949
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
19151950
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
19161951

1917-
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
1952+
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
19181953

19191954
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
19201955
}

src/llama-graph.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
// Graph input that combines the two halves of a hybrid memory model:
// a unified KV cache (attention) and a recurrent-state cache.
// Owns the per-half input objects and delegates tensor access to them.
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid(
            const llama_cparams & cparams,
            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
            std::unique_ptr<llm_graph_input_rs> inp_rs,
            const llama_memory_hybrid_context * mctx) :
        inp_attn(std::move(inp_attn)),
        inp_rs(std::move(inp_rs)),
        cparams(cparams),
        mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid() = default;

    void set_input(const llama_ubatch * ubatch) override;

    // returns true when the previously built graph can be reused for the
    // ubatch/memory context described by params (enables hybrid-graph reuse)
    bool can_reuse(const llm_graph_params & params) override;

    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
    std::unique_ptr<llm_graph_input_rs> inp_rs;

    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }

    // copied by value so set_input() can read cparams.causal_attn after build
    const llama_cparams cparams;

    // non-owning; refreshed by can_reuse() on each graph-reuse attempt
    const llama_memory_hybrid_context * mctx;
};
src/llama-memory-hybrid.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
222222
ubatches(std::move(ubatches)),
223223
// note: here we copy the ubatches. not sure if this is ideal
224224
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
225-
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
225+
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
226226
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
227227
}
228228

0 commit comments

Comments
 (0)