@@ -458,8 +458,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
458458}
459459
460460void llm_graph_input_mem_hybrid::set_input (const llama_ubatch * ubatch) {
461- inp_attn->set_input (ubatch);
462- inp_rs->set_input (ubatch);
461+ mctx->get_attn ()->set_input_k_idxs (inp_attn->self_k_idxs , ubatch);
462+ mctx->get_attn ()->set_input_v_idxs (inp_attn->self_v_idxs , ubatch);
463+
464+ mctx->get_attn ()->set_input_kq_mask (inp_attn->self_kq_mask , ubatch, cparams.causal_attn );
465+
466+ const int64_t n_rs = mctx->get_recr ()->get_n_rs ();
467+
468+ if (inp_rs->s_copy ) {
469+ GGML_ASSERT (ggml_backend_buffer_is_host (inp_rs->s_copy ->buffer ));
470+ int32_t * data = (int32_t *) inp_rs->s_copy ->data ;
471+
472+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
473+ for (uint32_t i = 0 ; i < n_rs; ++i) {
474+ data[i] = mctx->get_recr ()->s_copy (i);
475+ }
476+ }
477+ }
478+
479+ bool llm_graph_input_mem_hybrid::can_reuse (const llm_graph_params & params) {
480+ const auto * mctx = static_cast <const llama_memory_hybrid_context *>(params.mctx );
481+
482+ this ->mctx = mctx;
483+
484+ bool res = true ;
485+
486+ res &= inp_attn->self_k_idxs ->ne [0 ] == params.ubatch .n_tokens ;
487+ // res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
488+
489+ res &= inp_attn->self_kq_mask ->ne [0 ] == mctx->get_attn ()->get_n_kv ();
490+ res &= inp_attn->self_kq_mask ->ne [1 ] == GGML_PAD (params.ubatch .n_tokens , GGML_KQ_MASK_PAD);
491+
492+ res &= inp_rs->s_copy ->ne [0 ] == mctx->get_recr ()->get_n_rs ();
493+
494+ res &= inp_rs->s_copy_main ->ne [0 ] == params.ubatch .n_seqs ;
495+ res &= inp_rs->s_copy_extra ->ne [0 ] == mctx->get_recr ()->get_n_rs () - params.ubatch .n_seqs ;
496+
497+ return res;
463498}
464499
465500//
@@ -1914,7 +1949,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
19141949 auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr ());
19151950 auto inp_attn = build_attn_inp_kv_impl (ctx0, ubatch, hparams, cparams, mctx_cur->get_attn ());
19161951
1917- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move (inp_attn), std::move (inp_rs), mctx_cur);
1952+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move (inp_attn), std::move (inp_rs), mctx_cur);
19181953
19191954 return (llm_graph_input_mem_hybrid *) res->add_input (std::move (inp));
19201955}
0 commit comments