@@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->set_input_v_idxs(self_v_idxs, ubatch);
     }

     if (self_kq_mask) {
@@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
     }

-    if (self_kv_idxs_swa) {
-        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    if (self_k_idxs_swa) {
+        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    }
+
+    if (self_v_idxs_swa) {
+        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
     }

     if (self_kq_mask) {
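Note, not part of the diff: the single `self_kv_idxs` input is split into `self_k_idxs` and `self_v_idxs` because the K and V caches do not necessarily share a destination layout (the V cache can be stored transposed when flash attention is disabled), so each copy now gets its own index tensor. A minimal sketch of what the host side of `set_input_k_idxs` might look like, assuming the memory context records the destination cell of every token in the ubatch (`sinfo` is a hypothetical member, not confirmed by this diff):

```cpp
// Sketch under the assumptions above; not the actual implementation.
void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    int64_t * data = (int64_t *) dst->data;

    // one destination cell index per token in the micro-batch
    for (int64_t i = 0; i < n_tokens; ++i) {
        data[i] = sinfo.idxs[i]; // hypothetical: cache cell assigned to token i
    }
}
```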
@@ -345,6 +357,14 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_k_idxs) {
+        mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
@@ -362,7 +382,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }

-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
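The override keeps the base-class signature but now names the parameter and discards it explicitly. `GGML_UNUSED` is ggml's stock way to silence `-Wunused-parameter`; as far as I recall it is just a void cast in ggml.h:

```cpp
// as defined in ggml.h, to the best of my knowledge
#define GGML_UNUSED(x) (void)(x)
```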
@@ -1009,6 +1030,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {

     const auto n_kv = inp->mctx->get_attn()->get_n_kv();

+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1210,11 +1234,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

     const auto n_kv = mctx_cur->get_n_kv();

-    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-    ggml_set_input(inp->self_kv_idxs);
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
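Allocation of the index tensors moves behind the memory context here. Judging from the two lines this hunk removes, a plausible body for `build_input_k_idxs` is the same two calls, centralized (only the name comes from the diff; the body below is an assumption):

```cpp
// Sketch: centralizes the ggml_new_tensor_1d + ggml_set_input pair the hunk removes.
ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch * ubatch) const {
    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, ubatch->n_tokens);
    ggml_set_input(k_idxs);
    return k_idxs;
}
```

Keeping the allocation inside the cache lets each cache type size and type its index tensor however its `cpy_k`/`cpy_v` consumers need, without the graph builder knowing the layout.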
@@ -1245,10 +1268,11 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        const auto & kv_idxs = inp->get_kv_idxs();
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();

-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }

     const auto & kq_mask = inp->get_kq_mask();
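For context on how `cpy_k` might consume the index tensor: a natural fit is `ggml_set_rows`, which scatters the rows of the current K into the cache cells named by `k_idxs`. The sketch below assumes a per-layer cache tensor `layers[il].k` and one cache row per cell; none of these names or shapes are confirmed by this diff:

```cpp
// Sketch, assuming K is cached one row per cell and scattered with ggml_set_rows.
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int il) const {
    ggml_tensor * k = layers[il].k; // hypothetical: full K cache of layer il

    // view source and destination as 2-D so rows line up with cache cells
    ggml_tensor * k_cur_2d = ggml_reshape_2d(ctx, k_cur, k->ne[0], ggml_nelements(k_cur)/k->ne[0]);
    ggml_tensor * k_2d     = ggml_reshape_2d(ctx, k,     k->ne[0], ggml_nelements(k)    /k->ne[0]);

    return ggml_set_rows(ctx, k_2d, k_cur_2d, k_idxs);
}
```

A separate `v_idxs` tensor then leaves `cpy_v` free to target a transposed V layout whose destination offsets differ from K's.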
@@ -1307,15 +1331,15 @@ ggml_tensor * llm_graph_context::build_attn(

     // optionally store to KV cache
     if (k_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();

-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }

     if (v_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();

-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }

     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1419,8 +1443,11 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }

     const auto & kq_mask = inp->get_kq_mask();
@@ -1455,11 +1482,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

-        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs);
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1470,11 +1496,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

         const auto n_kv = mctx_cur->get_swa()->get_n_kv();

-        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs_swa);
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);

         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;