@@ -284,19 +284,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
284284}
285285
286286void llm_graph_input_attn_kv_unified::set_input (const llama_ubatch * ubatch) {
287- if (self_kq_mask) {
288- mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
289- }
287+ mctx->set_input_k_idxs (self_k_idxs, ubatch);
288+ mctx->set_input_v_idxs (self_v_idxs, ubatch);
289+
290+ mctx->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
290291}
291292
292293void llm_graph_input_attn_kv_unified_iswa::set_input (const llama_ubatch * ubatch) {
293- if (self_kq_mask) {
294- mctx->get_base ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
295- }
294+ mctx->get_base ()->set_input_k_idxs (self_k_idxs, ubatch);
295+ mctx->get_base ()->set_input_v_idxs (self_v_idxs, ubatch);
296296
297- if (self_kq_mask_swa) {
298- mctx->get_swa ()->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
299- }
297+ mctx->get_base ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
298+
299+ mctx->get_swa ()->set_input_k_idxs (self_k_idxs_swa, ubatch);
300+ mctx->get_swa ()->set_input_v_idxs (self_v_idxs_swa, ubatch);
301+
302+ mctx->get_swa ()->set_input_kq_mask (self_kq_mask_swa, ubatch, cparams.causal_attn );
300303}
301304
302305void llm_graph_input_attn_cross::set_input (const llama_ubatch * ubatch) {
@@ -337,9 +340,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
337340}
338341
339342void llm_graph_input_mem_hybrid::set_input (const llama_ubatch * ubatch) {
340- if (self_kq_mask) {
341- mctx->get_attn ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
342- }
343+ mctx->get_attn ()->set_input_k_idxs (self_k_idxs, ubatch);
344+ mctx->get_attn ()->set_input_v_idxs (self_v_idxs, ubatch);
345+
346+ mctx->get_attn ()->set_input_kq_mask (self_kq_mask, ubatch, cparams.causal_attn );
343347
344348 const int64_t n_rs = mctx->get_recr ()->get_n_rs ();
345349
@@ -354,7 +358,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
354358 }
355359}
356360
357- void llm_graph_input_one::set_input (const llama_ubatch *) {
361+ void llm_graph_input_one::set_input (const llama_ubatch * ubatch) {
362+ GGML_UNUSED (ubatch);
358363 GGML_ASSERT (one && ggml_nelements (one) == 1 );
359364 float f_one = 1.0f;
360365 ggml_backend_tensor_set (one, &f_one, 0 , sizeof (float ));
@@ -1001,6 +1006,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
10011006
10021007 const auto n_kv = inp->mctx ->get_attn ()->get_n_kv ();
10031008
1009+ inp->self_k_idxs = mctx_cur->get_attn ()->build_input_k_idxs (ctx0, ubatch);
1010+ inp->self_v_idxs = mctx_cur->get_attn ()->build_input_v_idxs (ctx0, ubatch);
1011+
10041012 inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
10051013 // cb(inp->self_kq_mask, "KQ_mask", -1);
10061014 ggml_set_input (inp->self_kq_mask );
@@ -1202,8 +1210,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
12021210
12031211 const auto n_kv = mctx_cur->get_n_kv ();
12041212
1213+ inp->self_k_idxs = mctx_cur->build_input_k_idxs (ctx0, ubatch);
1214+ inp->self_v_idxs = mctx_cur->build_input_v_idxs (ctx0, ubatch);
1215+
12051216 inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1206- // cb(inp->self_kq_mask, "KQ_mask", -1);
12071217 ggml_set_input (inp->self_kq_mask );
12081218
12091219 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1234,8 +1244,11 @@ ggml_tensor * llm_graph_context::build_attn(
12341244
12351245 // store to KV cache
12361246 {
1237- ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, il));
1238- ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, il));
1247+ const auto & k_idxs = inp->get_k_idxs ();
1248+ const auto & v_idxs = inp->get_v_idxs ();
1249+
1250+ ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, k_idxs, il));
1251+ ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, v_idxs, il));
12391252 }
12401253
12411254 const auto & kq_mask = inp->get_kq_mask ();
@@ -1294,11 +1307,15 @@ ggml_tensor * llm_graph_context::build_attn(
12941307
12951308 // optionally store to KV cache
12961309 if (k_cur) {
1297- ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, il));
1310+ const auto & k_idxs = is_swa ? inp->get_k_idxs_swa () : inp->get_k_idxs ();
1311+
1312+ ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, k_idxs, il));
12981313 }
12991314
13001315 if (v_cur) {
1301- ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, il));
1316+ const auto & v_idxs = is_swa ? inp->get_v_idxs_swa () : inp->get_v_idxs ();
1317+
1318+ ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, v_idxs, il));
13021319 }
13031320
13041321 const auto & kq_mask = is_swa ? inp->get_kq_mask_swa () : inp->get_kq_mask ();
@@ -1402,8 +1419,11 @@ ggml_tensor * llm_graph_context::build_attn(
14021419
14031420 // store to KV cache
14041421 {
1405- ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, il));
1406- ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, il));
1422+ const auto & k_idxs = inp->get_k_idxs ();
1423+ const auto & v_idxs = inp->get_v_idxs ();
1424+
1425+ ggml_build_forward_expand (gf, mctx_cur->cpy_k (ctx0, k_cur, k_idxs, il));
1426+ ggml_build_forward_expand (gf, mctx_cur->cpy_v (ctx0, v_cur, v_idxs, il));
14071427 }
14081428
14091429 const auto & kq_mask = inp->get_kq_mask ();
@@ -1438,8 +1458,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
14381458 {
14391459 const auto n_kv = mctx_cur->get_base ()->get_n_kv ();
14401460
1461+ inp->self_k_idxs = mctx_cur->get_base ()->build_input_k_idxs (ctx0, ubatch);
1462+ inp->self_v_idxs = mctx_cur->get_base ()->build_input_v_idxs (ctx0, ubatch);
1463+
14411464 inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1442- // cb(inp->self_kq_mask, "KQ_mask", -1);
14431465 ggml_set_input (inp->self_kq_mask );
14441466
14451467 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1450,8 +1472,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
14501472
14511473 const auto n_kv = mctx_cur->get_swa ()->get_n_kv ();
14521474
1475+ inp->self_k_idxs_swa = mctx_cur->get_swa ()->build_input_k_idxs (ctx0, ubatch);
1476+ inp->self_v_idxs_swa = mctx_cur->get_swa ()->build_input_v_idxs (ctx0, ubatch);
1477+
14531478 inp->self_kq_mask_swa = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1454- // cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
14551479 ggml_set_input (inp->self_kq_mask_swa );
14561480
14571481 inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask_swa , GGML_TYPE_F16) : inp->self_kq_mask_swa ;
0 commit comments