@@ -397,13 +397,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
-        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
-}
-
 //
 // llm_graph_context
 //
@@ -1261,7 +1254,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
+    // encapsulated in inp
+    const auto * kv_state = inp->kv_state;
 
     // store to KV cache
     {
@@ -1293,10 +1288,10 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
     const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
@@ -1310,25 +1305,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_hybrid_recurrent * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    return build_attn(
-        static_cast<llm_graph_input_attn_kv_unified *>(inp),
-        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
-    );
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
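
With the forwarding overload removed, a hybrid-recurrent model's graph builder now receives a plain llm_graph_input_attn_kv_unified * and calls the existing unified-KV build_attn() overload directly, which reads the attention child state from inp->kv_state instead of casting mstate. A rough sketch of the intended call site; the layer weights and local tensor names below are illustrative, not taken from this diff:

    // hypothetical per-layer attention call inside a hybrid model's build function
    auto * inp_attn = build_attn_inp_kv_hybrid_recurrent(); // llm_graph_input_attn_kv_unified *

    // Qcur/Kcur/Vcur computed as usual for the layer il ...
    cur = build_attn(inp_attn, gf,
            model.layers[il].wo, model.layers[il].bo,
            Qcur, Kcur, Vcur, nullptr, nullptr,
            1.0f/sqrtf(float(n_embd_head)), il);
    // resolves to the regular unified-KV overload, so no hybrid-specific
    // build_attn() wrapper is needed anymore
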
@@ -1471,13 +1448,17 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies,
+        const llama_kv_cache_recurrent_state * kv_state) const {
+
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
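
build_recurrent_state() now takes an explicit llama_kv_cache_recurrent_state pointer; passing nullptr falls back to casting mstate, so recurrent-only callers keep working while hybrid callers can hand in the recurrent child of their hybrid state. A minimal usage sketch under assumptions: the local tensor names are illustrative, and the accessor for the recurrent child is assumed (only get_state_attn() appears in this diff):

    // recurrent-only model: no explicit state, falls back to mstate internally
    ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy,
            hparams.n_embd_v_s(), n_seqs, false, nullptr);

    // hybrid model: pass the recurrent child state explicitly
    const auto * hyb = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
    ggml_tensor * ssm_h = build_recurrent_state(gf, ssm_states_all, state_copy,
            hparams.n_embd_v_s(), n_seqs, false,
            hyb->get_state_recurrent() /* assumed accessor, not shown in this hunk */);
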