@@ -1235,6 +1235,13 @@ struct llama_data_read {
12351235 return true ;
12361236 }
12371237
1238+ // COMPOSE_MEM_FIRST selects how read_kv_cache_data loads the transposed V cache: when 1, each layer's rows are first copied into an in-memory std::vector and uploaded with a single ggml_backend_tensor_set call; when 0, every row is uploaded with its own call. Other preprocessor macros for which in-memory precomposition speeds up KV cache loading can be added to the condition below.
1239+ #if defined(GGML_VULKAN_COMPUTE)
1240+ #define COMPOSE_MEM_FIRST 1
1241+ #else
1242+ #define COMPOSE_MEM_FIRST 0
1243+ #endif
1244+
12381245 bool read_kv_cache_data (struct llama_context * ctx, uint32_t cell_count) {
12391246 const struct llama_hparams & hparams = ctx->model .hparams ;
12401247 struct llama_kv_cache & kv_self = ctx->kv_self ;
@@ -1312,6 +1319,9 @@ struct llama_data_read {
13121319 }
13131320 }
13141321 } else {
1322+ #if COMPOSE_MEM_FIRST
1323+ std::vector<uint8_t > tmp_buf;
1324+ #endif
13151325 // For each layer, read the values for each cell (transposed)
13161326 for (uint32_t il = 0 ; il < n_layer; ++il) {
13171327 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
@@ -1343,11 +1353,24 @@ struct llama_data_read {
13431353 }
13441354
13451355 if (cell_count) {
1356+ #if COMPOSE_MEM_FIRST
1357+ const size_t buf_size = kv_self.size * v_size_el * n_embd_v_gqa;
1358+ if (tmp_buf.size () < buf_size) {
1359+ tmp_buf.resize (buf_size, 0 );
1360+ }
1361+ // For each row in the transposed matrix, read the values for the whole cell range
1362+ for (uint32_t j = 0 ; j < n_embd_v_gqa; ++j) {
1363+ const size_t dst_offset = (kv_self.head + j * kv_self.size ) * v_size_el;
1364+ memcpy (tmp_buf.data () + dst_offset, read (cell_count * v_size_el), cell_count * v_size_el); // read(cell_count * v_size_el);
1365+ }
1366+ ggml_backend_tensor_set (kv_self.v_l [il], tmp_buf.data (), 0 , buf_size);
1367+ #else
13461368 // For each row in the transposed matrix, read the values for the whole cell range
13471369 for (uint32_t j = 0 ; j < n_embd_v_gqa; ++j) {
13481370 const size_t dst_offset = (kv_self.head + j * kv_self.size ) * v_size_el;
13491371 ggml_backend_tensor_set (kv_self.v_l [il], read (cell_count * v_size_el), dst_offset, cell_count * v_size_el);
13501372 }
1373+ #endif
13511374 }
13521375 }
13531376 }
0 commit comments