@@ -235,7 +235,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
     const int64_t n_kv = kv_state->get_n_kv();
@@ -251,6 +251,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
+    const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_rs(kv_state->get_state_recurrent()) {
+}
+
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -354,6 +359,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
         kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
@@ -955,25 +967,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-    }
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1255,9 +1248,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
-    // encapsulated in inp
-    const auto * kv_state = inp->kv_state;
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
 
     // store to KV cache
     {
@@ -1289,15 +1280,14 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
+        hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->kv_state->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         // cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1306,7 +1296,57 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_re
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float         kq_scale,
+        int           il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@@ -1448,19 +1488,90 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        bool avoid_copies,
-        const llama_kv_cache_recurrent_state * kv_state) const {
+llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_kv = kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+
+    const auto n_kv = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;
 
-    if (kv_state == nullptr) {
-        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
     }
 
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0,
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+    return output_states;
+}
+
+llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
+    auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
+        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
+
+    const auto n_kv = inp->kv_state->get_n_kv();
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
+
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
     const auto rs_zero = kv_state->get_rs_z();
@@ -1478,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
         // copy states
         // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
         // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
         ggml_build_forward_expand(gf, output_states);
     } else {
         // FIXME: make the gathering operation happen before the copy below
@@ -1487,7 +1598,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     }
 
     // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
@@ -1497,9 +1608,9 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-        ggml_cgraph * gf,
-        ggml_tensor * state_copy,
-        const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
         int il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
@@ -1509,8 +1620,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_recurrent_state(
-            gf, token_shift_all, state_copy,
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
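
Call-site migration note: the net effect of this refactor is that model builders no longer thread a raw state_copy tensor through the graph; they create an llm_graph_input_rs (or its hybrid variant) once and hand it to the gather helpers. Below is a rough caller-side sketch, not part of this diff; it assumes the llama.cpp accessors kv_state->get_v_l(il), hparams.n_embd_v_s(), hparams.n_layer and the builder's ubatch member are available, and the per-layer loop is purely illustrative.

    // before: ggml_tensor * state_copy = build_inp_s_copy();
    //         build_recurrent_state(gf, s, state_copy, ...);
    auto * rs_inp = build_rs_inp_recurrent();  // creates and registers the s_copy input once

    for (int il = 0; il < (int) hparams.n_layer; ++il) {
        // token-shift states now flow through the same input object
        ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);

        // other per-layer recurrent states are gathered with the new build_rs() overload
        // (get_v_l() and n_embd_v_s() are assumed accessors, shown for illustration)
        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
        ggml_tensor * wkv_states = build_rs(rs_inp, gf, kv_state->get_v_l(il),
                hparams.n_embd_v_s(), ubatch.n_seqs);
        // ... use token_shift / wkv_states when building the layer ...
    }

Hybrid models would instead call build_rs_inp_hybrid_recurrent() and build_attn_inp_kv_hybrid_recurrent(), which route to the recurrent and attention halves of the hybrid cache state respectively.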