@@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
     if (self_kq_mask) {
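
A note on the hunk above: the single self_kv_idxs input is split into separate self_k_idxs and self_v_idxs tensors, and each one is handed to the memory context on its own. As a minimal sketch of the kind of host-side upload a set_input_*_idxs helper performs, assuming a 1D GGML_TYPE_I64 index tensor with one destination cell per ubatch token (the helper name and the shape are illustrative, not the actual llama.cpp implementation):

    #include <cstdint>
    #include <vector>

    #include "ggml.h"
    #include "ggml-backend.h"

    // illustrative helper: copy host-side destination indices into a graph input tensor
    static void set_row_idxs(ggml_tensor * dst, const std::vector<int64_t> & idxs) {
        GGML_ASSERT(dst->type == GGML_TYPE_I64);
        GGML_ASSERT((int64_t) idxs.size() == ggml_nelements(dst));

        // the input tensor may live in a backend buffer, so upload through ggml-backend
        ggml_backend_tensor_set(dst, idxs.data(), 0, idxs.size()*sizeof(int64_t));
    }
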
@@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
+    }
+
+    if (self_k_idxs_swa) {
+        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
     }
 
-    if (self_kv_idxs_swa) {
-        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    if (self_v_idxs_swa) {
+        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
     }
 
     if (self_kq_mask) {
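
For the iSWA path, the old pair self_kv_idxs / self_kv_idxs_swa becomes four index tensors, one K/V pair per cache. A hypothetical outline of the resulting input members and accessors, using only names that appear in this diff (the real declarations live in llama-graph.h and may differ):

    #include "ggml.h"

    // hypothetical outline, not the real class definition
    struct llm_graph_input_attn_kv_unified_iswa_outline {
        // destination indices for the base (non-SWA) KV cache
        ggml_tensor * self_k_idxs = nullptr; // I64
        ggml_tensor * self_v_idxs = nullptr; // I64

        // destination indices for the sliding-window KV cache
        ggml_tensor * self_k_idxs_swa = nullptr; // I64
        ggml_tensor * self_v_idxs_swa = nullptr; // I64

        // attention masks, untouched by these hunks
        ggml_tensor * self_kq_mask     = nullptr; // F32
        ggml_tensor * self_kq_mask_swa = nullptr; // F32

        ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
        ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
        ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
        ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
    };
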
@@ -1209,8 +1221,8 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     const auto n_kv   = mctx_cur->get_n_kv();
     const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
-    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-    ggml_set_input(inp->self_kv_idxs);
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
     inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
     ggml_set_input(inp->self_kq_mask);
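
Instead of allocating the index tensor inline, build_attn_inp_kv_unified now asks the memory context for it via build_input_k_idxs / build_input_v_idxs. A sketch of what such a builder could look like for the K indices, mirroring the removed inline code (one I64 entry per ubatch token); the V builder may pick a different size or layout, which is presumably why the two are now built separately:

    #include <cstdint>

    #include "ggml.h"

    // illustrative builder: allocate a per-token I64 index tensor and mark it as a graph input
    static ggml_tensor * build_row_idxs_input(ggml_context * ctx, int64_t n_tokens) {
        ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
        ggml_set_input(idxs);
        return idxs;
    }
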
@@ -1243,10 +1255,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = inp->get_kv_idxs();
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
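
In build_attn the two index tensors are consumed separately: k_idxs goes to cpy_k and v_idxs to cpy_v. As a rough illustration of how such a write can be expressed as a graph op, here is a sketch that scatters rows with ggml_set_rows (assumed signature: ctx, destination, source, row indices); the actual cpy_k/cpy_v implementations live in the KV-cache code, and the cache-view shape assumed here is a guess:

    #include "ggml.h"

    // rough sketch: write row i of k_cur into row k_idxs[i] of a 2D view of the K cache
    static ggml_tensor * cpy_k_sketch(
            ggml_context * ctx,
            ggml_tensor  * k_cache,  // assumed [n_embd_k, n_cache_cells] view of layer il's K cache
            ggml_tensor  * k_cur,    // [n_embd_k, n_tokens] K rows for the current ubatch
            ggml_tensor  * k_idxs) { // I64 [n_tokens], destination cell per token
        return ggml_set_rows(ctx, k_cache, k_cur, k_idxs);
    }
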
@@ -1299,10 +1312,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1444,8 +1458,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs);
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
@@ -1458,8 +1472,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs_swa);
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);