@@ -76,8 +76,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
7676 continue ;
7777 }
7878
79- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
80- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
79+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
80+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
8181
8282 const char * dev_name = " CPU" ;
8383
@@ -1346,7 +1346,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13461346 for (const auto & layer : layers) {
13471347 const uint32_t il = layer.il ;
13481348
1349- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
1349+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
13501350
13511351 // Write key type
13521352 const int32_t k_type_i = (int32_t )layer.k ->type ;
@@ -1368,7 +1368,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13681368 for (const auto & layer : layers) {
13691369 const uint32_t il = layer.il ;
13701370
1371- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1371+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
13721372
13731373 // Write value type
13741374 const int32_t v_type_i = (int32_t )layer.v ->type ;
@@ -1392,7 +1392,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13921392 for (const auto & layer : layers) {
13931393 const uint32_t il = layer.il ;
13941394
1395- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1395+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
13961396
13971397 // Write value type
13981398 const int32_t v_type_i = (int32_t )layer.v ->type ;
@@ -1526,7 +1526,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
15261526 for (const auto & layer : layers) {
15271527 const uint32_t il = layer.il ;
15281528
1529- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
1529+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
15301530
15311531 // Read type of key
15321532 int32_t k_type_i_ref;
@@ -1556,7 +1556,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
15561556 for (const auto & layer : layers) {
15571557 const uint32_t il = layer.il ;
15581558
1559- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1559+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
15601560
15611561 // Read type of value
15621562 int32_t v_type_i_ref;
@@ -1586,7 +1586,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
15861586 for (const auto & layer : layers) {
15871587 const uint32_t il = layer.il ;
15881588
1589- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1589+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
15901590
15911591 // Read type of value
15921592 int32_t v_type_i_ref;
@@ -1881,8 +1881,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
18811881 continue ;
18821882 }
18831883
1884- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (i) + hparams.n_embd_k_s ();
1885- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (i) + hparams.n_embd_v_s ();
1884+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (i) + hparams.n_embd_k_s (i );
1885+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (i) + hparams.n_embd_v_s (i );
18861886
18871887 const char * dev_name = " CPU" ;
18881888
@@ -2586,7 +2586,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
25862586 // Iterate and write all the keys first, each row is a cell
25872587 // Get whole range at a time
25882588 for (uint32_t il = 0 ; il < n_layer; ++il) {
2589- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
2589+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
25902590
25912591 // Write key type
25922592 const int32_t k_type_i = (int32_t )k_l[il]->type ;
@@ -2606,7 +2606,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
26062606
26072607 if (!v_trans) {
26082608 for (uint32_t il = 0 ; il < n_layer; ++il) {
2609- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2609+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
26102610
26112611 // Write value type
26122612 const int32_t v_type_i = (int32_t )v_l[il]->type ;
@@ -2627,7 +2627,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
26272627 // When v is transposed, we also need the element size and get the element ranges from each row
26282628 const uint32_t kv_size = size;
26292629 for (uint32_t il = 0 ; il < n_layer; ++il) {
2630- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2630+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
26312631
26322632 // Write value type
26332633 const int32_t v_type_i = (int32_t )v_l[il]->type ;
@@ -2774,7 +2774,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
27742774
27752775 // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
27762776 for (uint32_t il = 0 ; il < n_layer; ++il) {
2777- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
2777+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
27782778
27792779 // Read type of key
27802780 int32_t k_type_i_ref;
@@ -2802,7 +2802,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
28022802
28032803 if (!v_trans) {
28042804 for (uint32_t il = 0 ; il < n_layer; ++il) {
2805- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2805+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
28062806
28072807 // Read type of value
28082808 int32_t v_type_i_ref;
@@ -2830,7 +2830,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
28302830 } else {
28312831 // For each layer, read the values for each cell (transposed)
28322832 for (uint32_t il = 0 ; il < n_layer; ++il) {
2833- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2833+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
28342834
28352835 // Read type of value
28362836 int32_t v_type_i_ref;
0 commit comments