@@ -281,12 +281,24 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
+    if (self_kv_idxs_swa) {
+        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
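
The set_input changes above fill a new per-ubatch input tensor with the KV cell index that each token will be stored in; the cache-side implementation is not part of this diff. A minimal sketch of what set_input_kv_idxs could do, assuming the cells assigned to the ubatch are contiguous starting at a head cell (function and parameter names below are illustrative, not taken from the diff):

#include <cstdint>
#include "ggml.h"
#include "ggml-backend.h"

// Hypothetical sketch: write the destination KV cell index of each ubatch token
// into the new I64 input tensor. Assumes the ubatch occupies contiguous cells
// starting at `head_cur`.
static void set_input_kv_idxs_sketch(ggml_tensor * dst, uint32_t n_tokens, uint32_t head_cur) {
    GGML_ASSERT(dst->type == GGML_TYPE_I64);
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));

    int64_t * data = (int64_t *) dst->data;

    for (uint32_t i = 0; i < n_tokens; ++i) {
        data[i] = head_cur + i; // token i -> KV cell head_cur + i
    }
}
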
@@ -1198,6 +1210,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
+    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+    ggml_set_input(inp->self_kv_idxs);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1230,8 +1245,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1290,11 +1307,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
     }
 
     if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1398,8 +1419,8 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1434,6 +1455,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
@@ -1446,6 +1470,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs_swa);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
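
On the graph side, cpy_k/cpy_v now take the index tensor, with nullptr preserving the previous contiguous-write path (third build_attn hunk above). The cache-side consumption is not shown in this diff; a rough sketch of how it could branch on the indices, assuming ggml's row-scatter op ggml_set_rows(ctx, dst, src, idxs) and a per-layer K buffer of shape [n_embd_k_gqa, kv_size] (all names here are illustrative):

#include "ggml.h"

// Hypothetical sketch: build the graph nodes that write the current K rows into
// the cache buffer `k`, either scattered through kv_idxs or, when kv_idxs is
// nullptr, as a contiguous copy starting at the `head` cell (old behavior).
static ggml_tensor * cpy_k_sketch(ggml_context * ctx, ggml_tensor * k, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int64_t head) {
    const int64_t n_embd   = k->ne[0];                      // k assumed to be [n_embd_k_gqa, kv_size]
    const int64_t n_tokens = ggml_nelements(k_cur)/n_embd;

    ggml_tensor * k_cur_2d = ggml_reshape_2d(ctx, k_cur, n_embd, n_tokens);

    if (kv_idxs) {
        // scatter: row i of k_cur_2d is written to cell kv_idxs[i] of k
        return ggml_set_rows(ctx, k, k_cur_2d, kv_idxs);
    }

    // kv_idxs == nullptr: contiguous copy into a view at the head cell
    ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*n_embd, head*n_embd*ggml_element_size(k));

    return ggml_cpy(ctx, k_cur_2d, k_view);
}
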