@@ -221,15 +221,15 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_rs = mem_state->get_n_rs();
+    const int64_t n_rs = mctx->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->s_copy(i);
+            data[i] = mctx->s_copy(i);
         }
     }
 }
@@ -338,18 +338,18 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
-    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->get_state_recr()->s_copy(i);
+            data[i] = mctx->get_recr()->s_copy(i);
         }
     }
 }
@@ -999,14 +999,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1016,7 +1016,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     }
 
     {
-        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
 
         inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
         ggml_set_input(inp->s_copy);
@@ -1297,12 +1297,12 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const bool is_swa = hparams.is_swa(il);
 
-    const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
+    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1384,18 +1384,97 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+
+    {
+        const auto n_kv = mctx_cur->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
 
-    const auto n_kv = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        int32_t state_size,
+        int32_t n_seqs,
+        uint32_t n_kv,
+        uint32_t kv_head,
+        uint32_t kv_size,
+        int32_t rs_zero,
+        bool avoid_copies) const {
 
     ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
@@ -1426,11 +1505,11 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = kv_state->get_n_rs();
+    const auto n_rs = mctx_cur->get_n_rs();
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
@@ -1445,9 +1524,9 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1457,23 +1536,23 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
         ggml_cgraph * gf,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_state->get_r_l(il);
+    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
             inp, gf, token_shift_all,
@@ -1488,7 +1567,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1500,7 +1579,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
+        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
     );
 }
 