@@ -362,11 +362,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
+    }
+}
+
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn, false);
     }
 
     if (self_kq_mask_swa) {
-        kv_self->set_input_kq_mask_swa(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn, true);
     }
 }
 
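With the iSWA split, the unified cache's set_input_kq_mask gains an explicit swa flag and the separate set_input_kq_mask_swa entry point goes away. Below is a minimal, self-contained sketch of the visibility rule that flag is presumably toggling; the helper name and the exact window convention (>= vs >) are assumptions for illustration, not the cache's actual implementation.

#include <cstdint>

// Sketch (illustrative only): visibility test for one (query, key) position pair.
// `causal` mirrors cparams.causal_attn; `swa` is the new flag; `n_swa` is the window size.
static bool kq_pair_is_masked(int64_t pos_q, int64_t pos_k, bool causal, bool swa, int64_t n_swa) {
    if (causal && pos_k > pos_q) {
        return true;                       // future tokens are never visible
    }
    if (swa && pos_q - pos_k >= n_swa) {
        return true;                       // key falls outside the sliding window
    }
    return false;                          // visible -> mask value 0.0f, masked -> -INFINITY
}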
@@ -416,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer          (hparams.n_layer),
     n_rot            (hparams.n_rot),
     n_ctx            (cparams.n_ctx),
-    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
     n_head           (hparams.n_head()),
     n_head_kv        (hparams.n_head_kv()),
     n_embd_head_k    (hparams.n_embd_head_k),
@@ -1231,6 +1236,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     {
+        GGML_ASSERT(hparams.n_swa_pattern == 1 && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.n_swa == 0 && "Use llama_kv_cache_unified_iswa for SWA");
+
         const auto n_kv = kv_self->get_n();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
@@ -1240,10 +1248,79 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    if (hparams.n_swa_pattern > 1) {
-        GGML_ASSERT(hparams.n_swa > 0);
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+}
 
-        const auto n_kv = kv_self->get_n();
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_self->get_k(ctx0, il);
+    ggml_tensor * v = kv_self->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        //cb(cur, "kqv_wo", il);
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
+
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.n_swa_pattern > 1 && "Use llama_kv_cache_unified for non-SWA");
+        GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
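For orientation, this is roughly how a model's build function would be expected to drive the new iSWA path: one input-builder call per graph, then the per-layer build_attn overload (next hunk) routes each layer to the correct cache and mask. The sketch is written in comment form because it only makes sense inside llm_graph_context; the layer tensors and weight names (Qcur, Kcur, Vcur, model.layers[il].wo/bo, kq_scale) are illustrative, not taken from a specific model.

// Usage sketch (illustrative only):
//
//     auto * inp_attn = build_attn_inp_kv_unified_iswa();
//
//     for (int il = 0; il < n_layer; ++il) {
//         // ... compute Qcur, Kcur, Vcur for layer il ...
//         cur = build_attn(inp_attn, gf,
//                 model.layers[il].wo, model.layers[il].bo,
//                 Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
//         // ... output projection already applied; FFN, residual, etc. ...
//     }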
@@ -1252,11 +1329,11 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified * inp,
+        llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
@@ -1273,21 +1350,23 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const bool is_swa = hparams.is_swa(il);
+
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
     }
 
-    const bool is_swa = hparams.is_swa(il);
-
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_self->get_k(ctx0, il);
-    ggml_tensor * v = kv_self->get_v(ctx0, il);
+    ggml_tensor * k = kv->get_k(ctx0, il);
+    ggml_tensor * v = kv->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
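The per-layer routing above hinges on hparams.is_swa(il), which is driven by n_swa_pattern (hence the complementary asserts in the two input builders). Below is a self-contained sketch of one plausible convention, where every n_swa_pattern-th layer keeps full attention and the rest use the sliding window; the exact phase used by llama.cpp is an assumption here, and layer_is_swa is a hypothetical stand-in, not the real accessor.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for hparams.is_swa(il): with a pattern of N, the last
// layer in each group of N uses full attention and the other N-1 use SWA.
// n_swa_pattern == 1 therefore means "no SWA layers at all".
static bool layer_is_swa(uint32_t il, uint32_t n_swa_pattern) {
    if (n_swa_pattern <= 1) {
        return false;
    }
    return (il % n_swa_pattern) < (n_swa_pattern - 1);
}

int main() {
    // e.g. a Gemma-3-like 5:1 interleaving (pattern of 6)
    for (unsigned il = 0; il < 12; ++il) {
        std::printf("layer %2u -> %s\n", il, layer_is_swa(il, 6) ? "swa" : "full");
    }
    return 0;
}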