@@ -74,8 +74,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
7474 continue ;
7575 }
7676
77- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
78- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
77+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
78+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
7979
8080 const char * dev_name = " CPU" ;
8181
@@ -1255,7 +1255,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
12551255 for (const auto & layer : layers) {
12561256 const uint32_t il = layer.il ;
12571257
1258- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
1258+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
12591259
12601260 // Write key type
12611261 const int32_t k_type_i = (int32_t )layer.k ->type ;
@@ -1277,7 +1277,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
12771277 for (const auto & layer : layers) {
12781278 const uint32_t il = layer.il ;
12791279
1280- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1280+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
12811281
12821282 // Write value type
12831283 const int32_t v_type_i = (int32_t )layer.v ->type ;
@@ -1301,7 +1301,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13011301 for (const auto & layer : layers) {
13021302 const uint32_t il = layer.il ;
13031303
1304- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1304+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
13051305
13061306 // Write value type
13071307 const int32_t v_type_i = (int32_t )layer.v ->type ;
@@ -1438,7 +1438,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
14381438 for (const auto & layer : layers) {
14391439 const uint32_t il = layer.il ;
14401440
1441- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
1441+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
14421442
14431443 // Read type of key
14441444 int32_t k_type_i_ref;
@@ -1468,7 +1468,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
14681468 for (const auto & layer : layers) {
14691469 const uint32_t il = layer.il ;
14701470
1471- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1471+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
14721472
14731473 // Read type of value
14741474 int32_t v_type_i_ref;
@@ -1498,7 +1498,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
14981498 for (const auto & layer : layers) {
14991499 const uint32_t il = layer.il ;
15001500
1501- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1501+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
15021502
15031503 // Read type of value
15041504 int32_t v_type_i_ref;
@@ -1793,8 +1793,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
17931793 continue ;
17941794 }
17951795
1796- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (i) + hparams.n_embd_k_s ();
1797- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (i) + hparams.n_embd_v_s ();
1796+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (i) + hparams.n_embd_k_s (i );
1797+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (i) + hparams.n_embd_v_s (i );
17981798
17991799 const char * dev_name = " CPU" ;
18001800
@@ -2498,7 +2498,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
24982498 // Iterate and write all the keys first, each row is a cell
24992499 // Get whole range at a time
25002500 for (uint32_t il = 0 ; il < n_layer; ++il) {
2501- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
2501+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
25022502
25032503 // Write key type
25042504 const int32_t k_type_i = (int32_t )k_l[il]->type ;
@@ -2518,7 +2518,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
25182518
25192519 if (!v_trans) {
25202520 for (uint32_t il = 0 ; il < n_layer; ++il) {
2521- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2521+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
25222522
25232523 // Write value type
25242524 const int32_t v_type_i = (int32_t )v_l[il]->type ;
@@ -2539,7 +2539,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
25392539 // When v is transposed, we also need the element size and get the element ranges from each row
25402540 const uint32_t kv_size = size;
25412541 for (uint32_t il = 0 ; il < n_layer; ++il) {
2542- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2542+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
25432543
25442544 // Write value type
25452545 const int32_t v_type_i = (int32_t )v_l[il]->type ;
@@ -2686,7 +2686,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
26862686
26872687 // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
26882688 for (uint32_t il = 0 ; il < n_layer; ++il) {
2689- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
2689+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
26902690
26912691 // Read type of key
26922692 int32_t k_type_i_ref;
@@ -2714,7 +2714,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
27142714
27152715 if (!v_trans) {
27162716 for (uint32_t il = 0 ; il < n_layer; ++il) {
2717- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2717+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
27182718
27192719 // Read type of value
27202720 int32_t v_type_i_ref;
@@ -2742,7 +2742,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
27422742 } else {
27432743 // For each layer, read the values for each cell (transposed)
27442744 for (uint32_t il = 0 ; il < n_layer; ++il) {
2745- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2745+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
27462746
27472747 // Read type of value
27482748 int32_t v_type_i_ref;
0 commit comments