@@ -1018,16 +1018,33 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
 
     const int32_t ikv = map_layer_ids.at(il);
 
-    auto * k = layers[ikv].k;
+    ggml_tensor * k = layers[ikv].k;
+
+    const int64_t n_embd_head = k_cur->ne[0];
+    const int64_t n_head      = k_cur->ne[1];
+    const int64_t n_tokens    = k_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
-    const int64_t n_tokens = k_cur->ne[2];
+    // we can merge dims 0 and 1
+    // TODO: add ggml helper function for this?
+    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
 
-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
 
-    if (k->ne[2] > 1) {
-        k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+    const int64_t n_stream = k->ne[2];
+
+    if (n_stream > 1) {
+        const int64_t kv_size = get_size();
+
+        assert(n_embd_gqa == k->ne[0]);
+        assert(kv_size    == k->ne[1]);
+
+        // merge the buffer across all streams because the idxs are global
+        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
     }
 
+    // store the current K values into the cache
     return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }
 
@@ -1038,28 +1055,51 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens     = v_cur->ne[2];
+    const int64_t n_embd_head = v_cur->ne[0];
+    const int64_t n_head      = v_cur->ne[1];
+    const int64_t n_tokens    = v_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    // we can merge dims 0 and 1
+    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
 
+    const int64_t n_stream = v->ne[2];
+
+    // take this branch when FA is enabled (the V cache is not transposed)
     if (!v_trans) {
-        if (v->ne[2] > 1) {
-            v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
+
+        if (n_stream > 1) {
+            const int64_t kv_size = get_size();
+
+            assert(n_embd_gqa == v->ne[0]);
+            assert(kv_size    == v->ne[1]);
+
+            // merge the buffer across all streams because the idxs are global
+            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
         }
 
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
+    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
+        // we can merge dims 0, 1 and 2
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
+    } else {
+        // otherwise -> make a copy to get contiguous data
+        v_cur = ggml_cont_2d(ctx, v_cur, n_embd_gqa, n_tokens);
+    }
+
     // [TAG_V_CACHE_VARIABLE]
-    if (n_embd_v_gqa < v->ne[0]) {
-        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+    if (n_embd_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
     }
 
-    // the row becomes a single element
-    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
+    // in this branch the v_idxs are constructed in such a way that each row is a single head element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
 
     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
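
For readers unfamiliar with ggml_set_rows, below is a minimal, self-contained sketch of the scatter primitive that both cpy_k and cpy_v above build on: the third argument holds I64 row indices, and each row of the source tensor is written into the destination row selected by the corresponding index. All sizes and the F32 type in the sketch are hypothetical illustrations, not values taken from this commit; only the ggml calls themselves are real API.

// Hedged sketch, not part of the commit: shapes, sizes and the F32 type are
// assumptions chosen for illustration.
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_gqa = 8;   // row width (hypothetical)
    const int64_t kv_size    = 32;  // number of cache slots (hypothetical)
    const int64_t n_tokens   = 4;   // rows written in this step (hypothetical)

    // destination buffer: one row per cache slot
    struct ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_gqa, kv_size);
    // source rows: one row per token
    struct ggml_tensor * cur   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_gqa, n_tokens);
    // I64 destination index for each source row
    struct ggml_tensor * idxs  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    // scatter: cache[idxs[i], :] = cur[i, :]
    struct ggml_tensor * out = ggml_set_rows(ctx, cache, cur, idxs);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    // ... fill cur/idxs with data and evaluate the graph on a backend ...

    ggml_free(ctx);
    return 0;
}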