@@ -821,17 +821,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         return ggml_set_rows(ctx, v, v_cur, kv_idxs);
     }
 
-    // note: the V cache is transposed when not using flash attention
-    v_cur = ggml_transpose(ctx, v_cur);
-
-    // the row becomes a single element and we repeat the KV indices d_head times
+    // the row becomes a single element
     ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
 
-    v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+    // note: the V cache is transposed when not using flash attention
+    v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
 
-    // TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
-    kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+    // note: we can be more explicit here at the cost of extra cont
+    //       however, above we take advantage that a row of single element is always contiguous regardless of the row stride
+    //v_cur = ggml_transpose(ctx, v_cur);
+    //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
+    // we broadcast the KV indices n_embd_v_gqa times
+    // v       [1, n_kv,     n_embd_v_gqa]
+    // v_cur   [1, n_tokens, n_embd_v_gqa]
+    // kv_idxs [n_tokens, 1, 1]
     return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
 }
 
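For context, here is a minimal, shape-only sketch of the pattern this change relies on: the reshape/permute turns `v_cur` into `[1, n_tokens, n_embd]` so that each "row" is a single element, and `ggml_set_rows` can then take a plain `[n_tokens]` index tensor that is broadcast across the embedding dimension. This is not part of the change; the sizes `n_embd`, `n_tokens`, `n_kv` are placeholders, and it assumes a ggml build where `ggml_set_rows` supports broadcasting the row-index tensor (which is what makes the `ggml_repeat_4d` removal possible).

```cpp
#include "ggml.h"

#include <assert.h>
#include <stdint.h>
#include <stddef.h>

int main(void) {
    // shape-only sketch: no_alloc means no tensor data is allocated, we only check shapes
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    // placeholder sizes, not taken from the actual model
    const int64_t n_embd = 128, n_tokens = 4, n_kv = 32;

    // V cache and its single-element-row view: [1, n_kv, n_embd]
    struct ggml_tensor * v      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_kv);
    struct ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);

    // current values: [n_embd, n_tokens] -> [n_embd, 1, n_tokens] -> permute -> [1, n_tokens, n_embd]
    struct ggml_tensor * v_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
    v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);

    assert(v_cur->ne[0] == 1 && v_cur->ne[1] == n_tokens && v_cur->ne[2] == n_embd);

    // one destination row index per token; set_rows broadcasts it across the n_embd dimension
    struct ggml_tensor * kv_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    // result is a view of v_view with the same shape
    struct ggml_tensor * out = ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
    assert(out->ne[0] == 1 && out->ne[1] == n_kv && out->ne[2] == n_embd);

    ggml_free(ctx);
    return 0;
}
```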