@@ -746,21 +746,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * k = layers[ikv].k;

     const int64_t n_tokens = k_cur->ne[2];

+    if (kv_idxs) {
+        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    }
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*hparams.n_embd_k_gqa(il),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);

     return ggml_cpy(ctx, k_cur, k_view);
 }

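Note on the new path: when kv_idxs is provided, cpy_k switches from a ggml_cpy into a contiguous view to a ggml_set_rows scatter. k_cur is reshaped to one row per token, and row i is written to the cache row named by kv_idxs[i], so the destination slots no longer have to be contiguous. A minimal plain-C++ sketch of the scatter semantics, assuming a row-major [n_rows, row_size] K buffer (set_rows_sketch and the layout are illustrative, not the actual ggml kernel):

#include <cstdint>
#include <vector>

// Row-wise scatter as performed by ggml_set_rows on the K path:
// row i of k_cur lands at row idxs[i] of the cache (assumes the cache
// holds at least max(idxs)+1 rows and k_cur holds idxs.size() rows).
static void set_rows_sketch(std::vector<float> & k_cache, int64_t row_size,
                            const std::vector<float> & k_cur,
                            const std::vector<int64_t> & idxs) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        for (int64_t c = 0; c < row_size; ++c) {
            k_cache[idxs[i]*row_size + c] = k_cur[i*row_size + c];
        }
    }
}

The view+cpy fallback above expresses the same write as one dense copy into rows [head_cur, head_cur + n_tokens); set_rows generalizes it to arbitrary row placement.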
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * v = layers[ikv].v;
@@ -772,21 +776,48 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     ggml_tensor * v_view = nullptr;

     if (!v_trans) {
+        if (kv_idxs) {
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
         // note: the V cache is transposed when not using flash attention
+        if (kv_idxs) {
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+            v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
-
-        v_cur = ggml_transpose(ctx, v_cur);
     }

     return ggml_cpy(ctx, v_cur, v_view);
 }

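The transposed branch is the awkward case: with the V cache stored transposed (cache slots contiguous within each embedding row), there is no row-sized destination for ggml_set_rows to target. The diff therefore reshapes the cache and v_cur so that every element becomes its own one-wide row, and repeats the token indices once per embedding row via ggml_repeat_4d. A plain-C++ sketch of the degenerate, element-wise scatter this amounts to (layout and names are illustrative assumptions, not the ggml kernels):

#include <cstdint>
#include <vector>

// Element-wise scatter equivalent to the reshaped set_rows above:
// the token index selects a column of the transposed V cache, which is
// laid out here as n_embd rows of kv_size slots each.
static void set_rows_v_trans_sketch(std::vector<float> & v_cache, int64_t kv_size,
                                    const std::vector<float> & v_cur, // [n_tokens][n_embd]
                                    int64_t n_embd,
                                    const std::vector<int64_t> & idxs) {
    for (int64_t e = 0; e < n_embd; ++e) {
        for (size_t i = 0; i < idxs.size(); ++i) {
            v_cache[e*kv_size + idxs[i]] = v_cur[i*n_embd + e];
        }
    }
}

The TODO in the hunk points at the cost of this shape gymnastics: the index tensor is materialized once per embedding row by ggml_repeat_4d, so a dedicated column-scatter op would presumably avoid that blow-up.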
+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = head_cur + i;
+    }
+}
+
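set_input_kv_idxs is the host-side producer of those indices: it fills a 1-D I64 input tensor with the destination slot of each ubatch token. As written, the slots are the consecutive range head_cur + i, so the set_rows path touches exactly the same cells as the old view+cpy path; the indirection is what would permit non-contiguous placement later. For orientation, a hedged sketch of how the index input itself might be created on the graph-build side (that code is not part of this diff, so the helper name is hypothetical):

#include "ggml.h"

// Hypothetical helper: allocate the I64 index input that cpy_k/cpy_v consume.
// ggml_set_rows expects I64 row indices, matching the int64_t writes above.
static ggml_tensor * build_kv_idxs_input(ggml_context * ctx, int64_t n_tokens) {
    ggml_tensor * kv_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    ggml_set_input(kv_idxs); // populated per ubatch by set_input_kv_idxs()
    return kv_idxs;
}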
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;

@@ -1789,18 +1820,22 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
 }

 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }

+void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_kv_idxs(dst, ubatch, head);
+}
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }