@@ -1018,21 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
 
     const int32_t ikv = map_layer_ids.at(il);
 
-    auto * k = layers[ikv].k;
+    ggml_tensor * k = layers[ikv].k;
+
+    const int64_t n_embd_head = k_cur->ne[0];
+    const int64_t n_head      = k_cur->ne[1];
+    const int64_t n_token     = k_cur->ne[2];
 
-    const int64_t n_embd_k_gqa = k_cur->ne[0]*k_cur->ne[1];
-    const int64_t n_tokens = k_cur->ne[2];
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
     // we can merge dims 0 and 1
-    assert(k_cur->nb[0]*k_cur->ne[0] == k_cur->nb[1]);
+    // TODO: add ggml helper function for this?
+    assert(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
+
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_token, k_cur->nb[2], 0);
+
+    const int64_t n_stream = k->ne[2];
+
+    if (n_stream > 1) {
+        const uint64_t kv_size = get_size();
 
-    // k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
-    k_cur = ggml_view_2d(ctx, k_cur, n_embd_k_gqa, n_tokens, k_cur->nb[2], 0);
+        assert(n_embd_gqa == k->ne[0]);
+        assert(kv_size == k->ne[1]);
 
-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        // merge the buffer across all streams because the idxs are global
+        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
    }
 
+    // store the current K values into the cache
     return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }
 
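The crux of `cpy_k` is the no-copy 2D view: `k_cur` arrives as `[n_embd_head, n_head, n_token]`, and dims 0 and 1 can be merged into a single `n_embd_gqa` row only when the dim-1 stride is exactly one head row, which is what the `ggml_row_size` assert verifies. When the cache holds multiple streams, the `[n_embd_gqa, kv_size, n_stream]` buffer is flattened to 2D so the global row indices in `k_idxs` can address any stream. A minimal self-contained sketch of both steps (shapes are made-up example values, not taken from the diff; only public ggml API calls are used):

```cpp
#include "ggml.h"
#include <cassert>

int main() {
    ggml_init_params params = { /*mem_size*/ 16u*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    // current-batch K: e.g. 128-dim heads, 8 KV heads, 4 tokens
    ggml_tensor * k_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 8, 4);

    // dims 0 and 1 are mergeable iff the dim-1 stride equals one head row
    assert(ggml_row_size(k_cur->type, k_cur->ne[0]) == k_cur->nb[1]);

    // no-copy view: [n_embd_head, n_head, n_token] -> [n_embd_gqa, n_token]
    ggml_tensor * k_2d = ggml_view_2d(ctx, k_cur, 128*8, 4, k_cur->nb[2], 0);

    // a 2-stream cache [n_embd_gqa, kv_size, n_stream] ...
    ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128*8, 1024, 2);

    // ... flattened so the global row idxs used by ggml_set_rows reach every stream
    k = ggml_reshape_2d(ctx, k, 128*8, 1024*2);

    assert(k->ne[0] == k_2d->ne[0] && k->ne[1] == 2048);

    ggml_free(ctx);
    return 0;
}
```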
@@ -1043,39 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens = v_cur->ne[2];
+    const int64_t n_embd_head = v_cur->ne[0];
+    const int64_t n_head      = v_cur->ne[1];
+    const int64_t n_token     = v_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
     // we can merge dims 0 and 1
-    assert(v_cur->nb[0]*v_cur->ne[0] == v_cur->nb[1]);
+    assert(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
 
-    // v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    const int64_t n_stream = v->ne[2];
 
+    // take this branch when FA is enabled (the V cache is not transposed)
     if (!v_trans) {
-        v_cur = ggml_view_2d(ctx, v_cur, n_embd_v_gqa, n_tokens, v_cur->nb[2], 0);
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_token, v_cur->nb[2], 0);
+
+        if (n_stream > 1) {
+            const uint64_t kv_size = get_size();
+
+            assert(n_embd_gqa == v->ne[0]);
+            assert(kv_size == v->ne[1]);
 
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            // merge the buffer across all streams because the idxs are global
+            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
         }
 
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
-    if (v_cur->nb[1]*v_cur->ne[1] != v_cur->nb[2]) {
-        v_cur = ggml_cont_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
+        // we can merge dims 0, 1 and 2
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_token);
     } else {
-        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+        // otherwise -> make a copy to get contiguous data
+        v_cur = ggml_cont_2d(ctx, v_cur, n_embd_gqa, n_token);
     }
 
     // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    if (n_embd_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
     }
 
-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+    // in this branch the v_idxs are constructed in such a way that each row becomes a single element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
 
     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
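
When the V cache is transposed (the non-FA path), a token's value vector occupies one element in each of `n_embd_gqa` rows rather than one contiguous row, so `ggml_set_rows` cannot store a token per row. Instead both the cache and `v_cur` are reshaped to `[1, nelements]`, turning each element into its own row, and `v_idxs` supplies one flat destination index per element. A hedged sketch of how such per-element indices could line up (hypothetical helper, not the actual `v_idxs` construction in this patch): for a transposed per-stream layout `[kv_size, n_embd_gqa]`, element `j` of a token stored in cell `pos` of stream `s` flattens to `s*kv_size*n_embd_gqa + j*kv_size + pos`.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical illustration of per-element scatter indices for the transposed V cache.
// One index is produced per element of v_cur, in the same order in which v_cur is
// flattened to [1, n_embd_gqa*n_token] (dim 0 varies fastest in ggml).
std::vector<int64_t> v_idxs_trans(const std::vector<int64_t> & pos, // dst cell per token
                                  int64_t n_embd_gqa, int64_t kv_size, int64_t stream) {
    std::vector<int64_t> idxs;
    idxs.reserve(pos.size()*n_embd_gqa);
    for (int64_t p : pos) {
        for (int64_t j = 0; j < n_embd_gqa; ++j) {
            // element j of the token lands in row j, column p of the transposed cache
            idxs.push_back(stream*kv_size*n_embd_gqa + j*kv_size + p);
        }
    }
    return idxs;
}
```

The emission order matches `v_cur` after its reshape: each token's `n_embd_gqa` elements are consecutive, so entry `t*n_embd_gqa + j` of `v_idxs` addresses element `j` of token `t`.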