@@ -30,13 +30,11 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         bool     offload,
         uint32_t kv_size,
         uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
-    const int32_t n_layer = hparams.n_layer;
-
     has_shift = false;
     can_shift = true;

     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
-            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
+            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), hparams.n_layer, can_shift, padding);

     GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");

@@ -49,7 +47,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -73,26 +71,23 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     used = 0;

     cells.resize(kv_size);
-    layers.resize(n_layer);
-
-    for (int i = 0; i < n_layer; i++) {
-        auto & layer = layers[i];

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

         const char * dev_name = "CPU";

         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();

         if (offload) {
-            auto * dev = model.dev_layer(i);
+            auto * dev = model.dev_layer(il);
             buft = ggml_backend_dev_buffer_type(dev);

             dev_name = ggml_backend_dev_name(dev);
         }

-        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name);
+        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);

         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
@@ -104,7 +99,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(

         // TODO: enable
 #if 0
-        if (hparams.is_swa(i)) {
+        if (hparams.is_swa(il)) {
             k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, hparams.n_swa);
             v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, hparams.n_swa);
         } else {
@@ -116,11 +111,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
 #endif

-        ggml_format_name(k, "cache_k_l%d", i);
-        ggml_format_name(v, "cache_v_l%d", i);
+        ggml_format_name(k, "cache_k_l%d", il);
+        ggml_format_name(v, "cache_v_l%d", il);

-        layer.k = k;
-        layer.v = v;
+        layers.push_back({ il, k, v });
     }

     // allocate tensors and initialize the buffers to avoid NaNs in the padding
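
The constructor now appends one entry per layer via `layers.push_back({ il, k, v })` instead of resizing a preallocated vector and writing into `layers[i]`. The entry type itself is declared in the header and is not shown in this diff; a minimal sketch of what the brace-init implies, with the struct and member names as assumptions rather than quotes from the patch:

    // sketch only: the real declaration lives in the header, not in this diff
    struct kv_layer {
        uint32_t      il; // model layer index this cache slot serves
        ggml_tensor * k;  // K cache tensor for that layer
        ggml_tensor * v;  // V cache tensor for that layer (layout differs when v_trans is set)
    };

    std::vector<kv_layer> layers; // indexed by cache slot (ikv), not by model layer (il)

Carrying `il` inside each entry is what lets the later hunks iterate with `for (const auto & layer : layers)` while still feeding the true layer index into the `hparams` lookups, and the rest of the patch consistently sizes things by `layers.size()` rather than `hparams.n_layer`.
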
@@ -565,8 +559,10 @@ uint32_t llama_kv_cache_unified::get_n() const {
     return n;
 }

-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
-    auto * k = layers[il].k;
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t ikv) const {
+    auto * k = layers[ikv].k;
+
+    const uint32_t il = layers[ikv].il;

     return ggml_view_3d(ctx, k,
             hparams.n_embd_head_k, hparams.n_head_kv(il), n,
@@ -575,8 +571,10 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) cons
             0);
 }

-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const {
-    auto * v = layers[il].v;
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t ikv) const {
+    auto * v = layers[ikv].v;
+
+    const uint32_t il = layers[ikv].il;

     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
@@ -595,8 +593,10 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) cons
             0);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    auto * k = layers[il].k;
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t ikv) const {
+    auto * k = layers[ikv].k;
+
+    const uint32_t il = layers[ikv].il;

     const int64_t n_tokens = k_cur->ne[2];

@@ -607,8 +607,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    auto * v = layers[il].v;
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t ikv) const {
+    auto * v = layers[ikv].v;
+
+    const uint32_t il = layers[ikv].il;

     const int64_t n_tokens = v_cur->ne[2];

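
get_k/get_v/cpy_k/cpy_v now take the position inside `layers` (`ikv`) and recover the model layer index (`il`) from the stored entry, so the shapes coming from `hparams` still refer to the real layer while the tensors are addressed by cache slot. A small self-contained sketch of that indexing split, under the assumption that the cache may cover only a subset of the model's layers (the names and the subset are illustrative, not taken from the patch):

    // standalone illustration, not llama.cpp API: slot index (ikv) vs model layer (il)
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct kv_layer { uint32_t il; /* ggml_tensor * k; ggml_tensor * v; */ };

    int main() {
        // hypothetical cache holding slots for model layers 0, 2 and 4 only
        std::vector<kv_layer> layers = { {0}, {2}, {4} };

        for (uint32_t ikv = 0; ikv < layers.size(); ++ikv) {
            // hparams lookups (n_head_kv, n_embd_k_gqa, ...) must use layers[ikv].il,
            // while the K/V tensors themselves are picked by the slot index ikv
            printf("slot %u -> model layer %u\n", (unsigned) ikv, (unsigned) layers[ikv].il);
        }
        return 0;
    }
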
@@ -890,8 +892,6 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    const auto & n_layer = hparams.n_layer;
-
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;

@@ -904,8 +904,8 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
     ggml_set_input(inp->k_shift);

-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const auto & layer = layers[il];
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;

         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
@@ -1028,8 +1028,8 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
             nm++;
         }

-        for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
-            const auto & layer = layers[il];
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;

             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
@@ -1084,7 +1084,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 }

 bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
-    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_layer = layers.size();

     const uint32_t n_kv   = cell_max();
     const uint32_t n_used = used;
@@ -1309,7 +1309,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::

 void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
     const uint32_t v_trans = this->v_trans ? 1 : 0;
-    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_layer = layers.size();

     io.write(&v_trans, sizeof(v_trans));
     io.write(&n_layer, sizeof(n_layer));
@@ -1318,8 +1318,8 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::

     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const auto & layer = layers[il];
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;

         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

@@ -1340,8 +1340,8 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     }

     if (!v_trans) {
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const auto & layer = layers[il];
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;

             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -1364,8 +1364,8 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;

-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const auto & layer = layers[il];
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;

             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -1485,8 +1485,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     io.read_to(&v_trans, sizeof(v_trans));
     io.read_to(&n_layer, sizeof(n_layer));

-    if (n_layer != hparams.n_layer) {
-        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+    if (n_layer != layers.size()) {
+        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
         return false;
     }
     if (cell_count > size) {
@@ -1499,8 +1499,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     }

     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const auto & layer = layers[il];
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;

         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

@@ -1529,8 +1529,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     }

     if (!this->v_trans) {
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const auto & layer = layers[il];
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;

             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -1559,8 +1559,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         }
     } else {
         // For each layer, read the values for each cell (transposed)
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const auto & layer = layers[il];
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;

             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
