@@ -436,8 +436,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    inp_attn->set_input(ubatch);
-    inp_rs->set_input(ubatch);
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    // res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    return res;
 }
 
 //
@@ -1848,7 +1883,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }
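
The new `llm_graph_input_mem_hybrid::can_reuse()` in the diff above builds its answer by AND-ing a series of shape checks (`res &= ...`): the previously built graph is reused only if every cached input tensor still matches the dimensions implied by the new ubatch and memory context. Below is a minimal standalone sketch of that pattern, not the llama.cpp implementation: all names (`FakeTensor`, `GraphParams`, `can_reuse_sketch`) are hypothetical, and it drops details such as the `GGML_PAD` rounding of the mask height.

```cpp
// Illustrative sketch of the "accumulate shape checks, reuse only if all pass"
// pattern used by can_reuse(). Hypothetical types; not part of llama.cpp.
#include <cstdint>
#include <cstdio>

struct FakeTensor {
    int64_t ne[2]; // tensor dimensions, analogous to ggml_tensor::ne
};

struct GraphParams {
    int64_t n_tokens; // tokens in the new ubatch
    int64_t n_kv;     // current KV-cache size
};

// Returns true only if every cached input tensor still matches the new
// parameters, i.e. the previously built graph can be reused as-is.
static bool can_reuse_sketch(const FakeTensor & k_idxs, const FakeTensor & kq_mask, const GraphParams & params) {
    bool res = true;

    res &= k_idxs.ne[0]  == params.n_tokens; // per-token K indices
    res &= kq_mask.ne[0] == params.n_kv;     // mask width tracks the KV cache
    res &= kq_mask.ne[1] == params.n_tokens; // mask height tracks the batch (unpadded here)

    return res;
}

int main() {
    const FakeTensor k_idxs  = {{32, 1}};
    const FakeTensor kq_mask = {{256, 32}};

    const GraphParams same  = {32, 256};
    const GraphParams grown = {64, 256};

    std::printf("same batch : reuse = %d\n", (int) can_reuse_sketch(k_idxs, kq_mask, same));  // prints 1
    std::printf("grown batch: reuse = %d\n", (int) can_reuse_sketch(k_idxs, kq_mask, grown)); // prints 0
}
```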