@@ -937,17 +937,17 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
                 hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
                 ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                size_virt,
+                size_virt, // v->nb[3]
                 size_virt*sinfo.s0);
     }

     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
             n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            size_virt,
-            size_virt*sinfo.s0);
+            ggml_row_size(v->type, v->ne[1]*n_seq_virt*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, v->ne[1]*n_seq_virt),                       // v->nb[2]
+            ggml_row_size(v->type, v->ne[1]),                                  // v->nb[3]
+            ggml_row_size(v->type, v->ne[1]*sinfo.s0));
 }

 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
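The transposed view above now folds the virtual sequences into the strides instead of a separate size_virt offset: v->nb[3] steps between virtual sequences, and the view offset is expressed in the same units via sinfo.s0. As a rough sketch of the implied address arithmetic (plain C++, not part of the patch; it reads v->ne[1] as the per-sequence KV slot count, which is an assumption about the cache layout):

#include <cstddef>

// hypothetical helper: element offset of V[i, h, e, s] under the transposed
// 4D view built in get_v() above
//   i: KV slot, h: KV head, e: channel within the head, s: virtual sequence
size_t v_trans_elem_offset(size_t i, size_t h, size_t e, size_t s,
                           size_t kv_size, size_t n_seq_virt, size_t n_embd_head_v) {
    const size_t nb1 = kv_size*n_seq_virt*n_embd_head_v; // v->nb[1]: next KV head
    const size_t nb2 = kv_size*n_seq_virt;               // v->nb[2]: next channel
    const size_t nb3 = kv_size;                          // v->nb[3]: next virtual sequence
    return i + h*nb1 + e*nb2 + s*nb3;                    // ggml scales these by ggml_row_size()
}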
@@ -961,20 +961,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);

     if (kv_idxs && supports_set_rows) {
-        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
-
-        const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
-
-        ggml_tensor * k_view = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], ns,
-                ggml_row_size(k->type, k->ne[0]),
-                size_virt,
-                size_virt*sinfo.s0);
-
-        k_cur = ggml_reshape_3d(ctx, k_cur, k_cur->ne[0], k_cur->ne[1]/ns, ns);
-
-        kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
+        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);

-        return ggml_set_rows(ctx, k_view, k_cur, kv_idxs);
+        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
     }

     // TODO: fallback to old ggml_cpy() method for backwards compatibility
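With the K cache flattened to 2D, a single ggml_set_rows() call covers all virtual sequences at once: the per-sequence view, the /ns reshapes, and the size_virt stride bookkeeping all disappear, and the sequence offset moves into the indices themselves (see set_input_kv_idxs() below). For intuition, the per-row semantics of ggml_set_rows() on a 2D destination are roughly the following (scalar sketch for float data, not the actual ggml kernel, which dispatches on type and supports broadcasting):

#include <cstdint>
#include <cstring>

// sketch: copy each source row into the destination row named by idxs[r];
// with pre-offset indices, rows of any virtual sequence can be addressed
void set_rows_ref(float * dst, const float * src, const int64_t * idxs,
                  int64_t n_embd, int64_t n_src_rows) {
    for (int64_t r = 0; r < n_src_rows; ++r) {
        std::memcpy(dst + idxs[r]*n_embd, src + r*n_embd, n_embd*sizeof(float));
    }
}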
@@ -1000,45 +989,27 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

     if (kv_idxs && supports_set_rows) {
-        const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
-
-        const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
-
         if (!v_trans) {
-            ggml_tensor * v_view = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], ns,
-                    ggml_row_size(v->type, v->ne[0]),
-                    size_virt,
-                    size_virt*sinfo.s0);
-
-            v_cur = ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], v_cur->ne[1]/ns, ns);
-
-            kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
+            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);

-            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }

         // the row becomes a single element
-        ggml_tensor * v_view = ggml_view_4d(ctx, v, 1, v->ne[1], v->ne[0], ns,
-                ggml_row_size(v->type, 1),
-                ggml_row_size(v->type, v->ne[1]),
-                size_virt,
-                size_virt*sinfo.s0);
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);

         // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_4d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]/ns, ns), 2, 0, 1, 3);
+        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);

         // note: we can be more explicit here at the cost of extra cont
         // however, above we take advantage that a row of single element is always contiguous regardless of the row stride
-        // v_cur = ggml_reshape_3d(ctx, v_cur, n_embd_v_gqa, v_cur->ne[1]/ns, ns);
         // v_cur = ggml_transpose(ctx, v_cur);
-        // v_cur = ggml_cont_4d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1], v_cur->ne[2]);
+        // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);

         // we broadcast the KV indices n_embd_v_gqa times
-        // v       [1, n_kv, n_embd_v_gqa, ns]
-        // v_cur   [1, n_tokens/ns, n_embd_v_gqa, ns]
-        // kv_idxs [n_tokens/ns, 1, ns]
-
-        kv_idxs = ggml_reshape_3d(ctx, kv_idxs, n_tokens/ns, 1, ns);
+        // v       [1, n_kv*n_seq_virt, n_embd_v_gqa]
+        // v_cur   [1, n_tokens, n_embd_v_gqa]
+        // kv_idxs [n_tokens, 1, 1]

         return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
     }
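The transposed case cannot use whole-row copies, since a token's values are strided across the cache; instead each destination row degenerates to a single element and the same token indices are broadcast across all n_embd_v_gqa channels. Spelled out against the shape comments above, the update amounts to the following (scalar sketch for float data, not the actual kernel; n_kv_total stands for n_kv*n_seq_virt, and v_cur is read in its pre-permute [n_embd_v_gqa, n_tokens] layout):

#include <cstdint>

// sketch of the broadcast ggml_set_rows() for the transposed V cache:
//   v       [1, n_kv_total, n_embd_v_gqa]  (single-element rows)
//   v_cur   [n_embd_v_gqa, n_tokens]       (pre-permute layout)
//   kv_idxs [n_tokens]                     (reused for every channel)
void set_rows_v_trans_ref(float * v, const float * v_cur, const int64_t * kv_idxs,
                          int64_t n_kv_total, int64_t n_embd_v_gqa, int64_t n_tokens) {
    for (int64_t j = 0; j < n_embd_v_gqa; ++j) {     // broadcast dimension
        for (int64_t t = 0; t < n_tokens; ++t) {
            // single-element rows are contiguous regardless of the row stride
            v[j*n_kv_total + kv_idxs[t]] = v_cur[t*n_embd_v_gqa + j];
        }
    }
}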
@@ -1077,8 +1048,10 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
     int64_t * data = (int64_t *) dst->data;

     for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+        const int64_t offs = sinfo.seq_id_virt[s]*get_size();
+
         for (uint32_t i = 0; i < sinfo.size(); ++i) {
-            data[s*sinfo.size() + i] = sinfo.idxs[s][i];
+            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
         }
     }
 }
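Worked example with hypothetical numbers: with get_size() == 8, two virtual sequences with seq_id_virt = {0, 2}, and slots idxs = {3, 4} in each, the emitted indices are {3, 4, 19, 20}: slot 3 of virtual sequence 2 becomes absolute row 2*8 + 3 = 19 of the flattened cache, which is exactly what the 2D reshapes in cpy_k() and cpy_v() above expect.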