@@ -777,23 +777,32 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (!v_trans) {
         if (kv_idxs) {
-            return ggml_set_rows(ctx, v, ggml_reshape_2d(ctx, v_cur, v->ne[0], n_tokens), kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }
 
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
+        // note: the V cache is transposed when not using flash attention
         if (kv_idxs) {
-            GGML_ABORT("TODO: implement kv_idxs for transposed V cache -- for now use flash attention");
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+            v_cur = ggml_cont(ctx, v_cur);
+            v_cur = ggml_reshape_3d(ctx, v_cur, 1, n_tokens, hparams.n_embd_v_gqa(il));
+
+            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
        }
 
-        // note: the V cache is transposed when not using flash attention
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
-
-        v_cur = ggml_transpose(ctx, v_cur);
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
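
For context, here is a minimal standalone sketch of the index-repetition trick used in the transposed branch: because each destination "row" of the transposed cache is a single element, the per-token slot indices have to be repeated once per embedding dimension before ggml_set_rows can scatter them. The tensor names (v_cache, v_cur, idxs) and the dimensions d_head, n_cache, n_tokens below are made up for illustration and are not taken from the commit; only the op sequence mirrors the new code.

// sketch only: mirrors the transposed-cache branch above with small, made-up sizes
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t d_head   = 4; // embedding size per value vector (stand-in for n_embd_v_gqa)
    const int64_t n_cache  = 8; // number of cache slots (stand-in for kv_size)
    const int64_t n_tokens = 2; // tokens being written

    // transposed V cache: the cache slot is the fastest dimension, ne = [n_cache, d_head]
    struct ggml_tensor * v_cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_cache, d_head);
    // incoming values: one row per token, ne = [d_head, n_tokens]
    struct ggml_tensor * v_cur   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head,  n_tokens);
    // destination slot per token
    struct ggml_tensor * idxs    = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    // view the cache so that every element becomes its own row: ne = [1, n_cache, d_head]
    struct ggml_tensor * v_view = ggml_reshape_3d(ctx, v_cache, 1, n_cache, d_head);

    // transpose the incoming values and flatten them the same way: ne = [1, n_tokens, d_head]
    struct ggml_tensor * v_t = ggml_cont(ctx, ggml_transpose(ctx, v_cur));
    v_t = ggml_reshape_3d(ctx, v_t, 1, n_tokens, d_head);

    // repeat the slot indices once per embedding dimension: ne = [n_tokens, d_head]
    struct ggml_tensor * idxs_rep = ggml_repeat_4d(ctx, idxs, n_tokens, d_head, 1, 1);

    // scatter the single-element rows into the transposed cache
    struct ggml_tensor * out = ggml_set_rows(ctx, v_view, v_t, idxs_rep);

    printf("set_rows dst: [%lld, %lld, %lld]\n",
            (long long) out->ne[0], (long long) out->ne[1], (long long) out->ne[2]);

    ggml_free(ctx);
    return 0;
}

As the commit's own TODO notes, duplicating the indices d_head times is likely not optimal, but it lets the transposed (non-flash-attention) V cache reuse the same ggml_set_rows path instead of aborting.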