@@ -808,7 +808,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -818,8 +818,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -832,7 +832,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -842,9 +842,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (kv_idxs && supports_set_rows) {
+    if (v_idxs && supports_set_rows) {
         if (!v_trans) {
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
         // the row becomes a single element
@@ -859,10 +859,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1, n_kv,     n_embd_v_gqa]
-        // v_cur   [1, n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1, 1]
-        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        // v      [1, n_kv,     n_embd_v_gqa]
+        // v_cur  [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
 
     // TODO: fallback to old ggml_cpy() method for backwards compatibility
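Note on the transposed-V branch above: with the index tensor shaped `[n_tokens, 1, 1]` and broadcast across `n_embd_v_gqa`, `ggml_set_rows` effectively scatters one element per (embedding component, token) pair into non-contiguous cache slots. Below is a minimal standalone C++ emulation of that scatter semantics using plain vectors instead of ggml tensors; the function name `set_rows_v_trans` and the tiny shapes are illustrative only, not part of this patch:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Emulates the transposed-V ggml_set_rows() call above:
// v is viewed as [1, n_kv, n_embd_v_gqa], v_cur as [1, n_tokens, n_embd_v_gqa],
// and the 1-D index tensor [n_tokens] is broadcast over the n_embd_v_gqa dim,
// so each embedding component e scatters v_cur[e][i] into v[e][idxs[i]].
static void set_rows_v_trans(std::vector<float> & v,             // n_embd_v_gqa * n_kv, transposed layout
                             const std::vector<float> & v_cur,   // n_tokens * n_embd_v_gqa, row-major
                             const std::vector<int64_t> & idxs,  // cache slot per token
                             int64_t n_kv, int64_t n_embd_v_gqa) {
    const int64_t n_tokens = (int64_t) idxs.size();
    for (int64_t e = 0; e < n_embd_v_gqa; ++e) {
        for (int64_t i = 0; i < n_tokens; ++i) {
            // transposed cache: element (e, slot) lives at offset e*n_kv + slot
            v[e*n_kv + idxs[i]] = v_cur[i*n_embd_v_gqa + e];
        }
    }
}

int main() {
    const int64_t n_kv = 4, n_embd_v_gqa = 2;
    std::vector<float> v(n_kv*n_embd_v_gqa, 0.0f);
    // two tokens written to non-contiguous slots 3 and 1
    set_rows_v_trans(v, {10, 11, 20, 21}, {3, 1}, n_kv, n_embd_v_gqa);
    for (float x : v) printf("%g ", x); // prints: 0 20 0 10 0 21 0 11
    printf("\n");
}
```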
@@ -885,7 +885,42 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs[i];
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
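The `build_input_*` / `set_input_*` split follows ggml's deferred-input pattern: the index tensor is allocated and flagged with `ggml_set_input()` once at graph-build time, and its data is written on the host per ubatch. A minimal sketch of that same pattern in isolation, assuming `ggml.h` is on the include path and the program links against ggml; the slot values written here are made up for illustration:

```cpp
#include "ggml.h"

int main() {
    // small CPU-side context; with no_alloc=false, tensor data is host-writable
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    // graph-build phase: one I64 index per token, marked as a graph input
    const int64_t n_tokens = 4;
    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    ggml_set_input(k_idxs);

    // set-input phase: fill in the cache slot chosen for each token
    int64_t * data = (int64_t *) k_idxs->data;
    for (int64_t i = 0; i < n_tokens; ++i) {
        data[i] = 2*i + 1; // hypothetical non-contiguous slots
    }

    ggml_free(ctx);
}
```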
@@ -1906,20 +1941,32 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
-void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
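For context on why the wrappers pass `sinfos[i_cur]`: the unified-cache context appears to keep one `slot_info` per ubatch of a split batch and advance `i_cur` as each ubatch is processed, so the same graph inputs get rebound to different slot sets. A self-contained C++ emulation of that bookkeeping; every type and method here is a hypothetical stand-in, not the llama.cpp API:

```cpp
#include <cstdint>
#include <cstddef>
#include <vector>

// Stand-in for the slot assignment of one ubatch: cache slot per token.
struct slot_info {
    std::vector<uint32_t> idxs;
};

// Stand-in for llama_kv_cache_unified_context's per-ubatch cursor.
struct kv_context {
    std::vector<slot_info> sinfos;
    size_t i_cur = 0;

    const slot_info & current() const { return sinfos[i_cur]; }
    void next() { ++i_cur; } // advanced once per processed ubatch
};

int main() {
    kv_context ctx;
    ctx.sinfos = { { {0, 1, 2} }, { {7, 3, 9} } }; // two ubatches, non-contiguous slots
    for (size_t u = 0; u < ctx.sinfos.size(); ++u) {
        const slot_info & s = ctx.current();
        (void) s; // set_input_k_idxs/set_input_v_idxs would copy s.idxs here
        ctx.next();
    }
}
```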