@@ -335,6 +335,11 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    inp_attn->set_input(ubatch);
+    inp_rs->set_input(ubatch);
+}
+
 void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
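The new hybrid input simply forwards `set_input` to its two wrapped inputs. The matching declaration lives in llama-graph.h and is not part of this diff; below is a minimal sketch of what it plausibly looks like, with the accessor names assumed from how the type is used later in this file.

```cpp
// Sketch only -- the real declaration is in llama-graph.h; accessor and
// member names here are assumptions, not the verified header contents.
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid(
            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
            std::unique_ptr<llm_graph_input_rs>              inp_rs,
            const llama_memory_hybrid_context *              mctx) :
        inp_attn(std::move(inp_attn)),
        inp_rs  (std::move(inp_rs)),
        mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid() = default;

    void set_input(const llama_ubatch * ubatch) override;

    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
    llm_graph_input_rs              * get_rs()   const { return inp_rs.get();   }

private:
    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
    std::unique_ptr<llm_graph_input_rs>              inp_rs;

    const llama_memory_hybrid_context * mctx;
};
```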
@@ -1147,17 +1152,20 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+           ggml_context * ctx0,
+     const llama_ubatch & ubatch,
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_unified_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
         const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1168,6 +1176,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    return inp;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
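Caller-side, the refactor removes the optional context parameter: the public builder now always derives its context from the graph's own `mctx`, while the hybrid path further down injects a sub-context through the static `_impl` directly. A hypothetical before/after at a call site:

```cpp
// before this change, callers could pass a context explicitly, or rely
// on the nullptr fallback to the graph's own memory context:
//     auto * inp_attn = build_attn_inp_kv_unified(nullptr);
// after it, the cast from `mctx` happens inside the builder:
auto * inp_attn = build_attn_inp_kv_unified();
```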
@@ -1346,10 +1362,11 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
-    }
+// TODO: maybe factor the inner implementation into a separate function,
+//       as with the non-sliding-window equivalent,
+//       once sliding-window hybrid caches are a thing.
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
@@ -1417,10 +1434,9 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+        ggml_context * ctx0,
+        const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
@@ -1429,6 +1445,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurren
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    return inp;
+}
+
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
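The recurrent-state input gets the same treatment: a free `build_rs_inp_impl()` that takes its dependencies explicitly, plus a thin member wrapper. A hypothetical recurrent-model call site after the change; the token-shift helper names come from this file, but their exact signatures are assumed here:

```cpp
// build_rs_inp() no longer takes a context argument; a recurrent model
// creates the input once and reuses it for every layer's state load/store.
auto * rs_inp = build_rs_inp();

for (int il = 0; il < n_layer; ++il) {
    // per-layer token-shift round trip (RWKV-style models); signatures assumed
    ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
    // ... layer body using token_shift ...
    ggml_build_forward_expand(gf,
            build_rwkv_token_shift_store(rs_inp, token_shift, ubatch, il));
}
```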
@@ -1486,6 +1510,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_pooling(
         ggml_cgraph * gf,
         ggml_tensor * cls,
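Finally, a hypothetical consumer of the new hybrid builder: a model interleaving recurrent and attention layers would pull both sub-inputs out of the single hybrid input. The accessor names follow the class sketch above, and the per-layer helpers are placeholders:

```cpp
// Hypothetical graph-build loop for a hybrid (e.g. Jamba-like) model;
// `cur`, `gf` and `n_layer` come from the surrounding builder context.
auto * inp = build_inp_mem_hybrid();

for (int il = 0; il < n_layer; ++il) {
    if (hparams.is_recurrent(il)) {
        // recurrent layers consume the recurrent-state input
        cur = build_layer_recurrent(inp->get_rs(), gf, cur, il);   // hypothetical helper
    } else {
        // attention layers consume the unified-KV attention input
        cur = build_layer_attn(inp->get_attn(), gf, cur, il);      // hypothetical helper
    }
}
```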