@@ -75,8 +75,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
7575 continue ;
7676 }
7777
78- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
79- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
78+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
79+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
8080
8181 const char * dev_name = " CPU" ;
8282
@@ -1369,7 +1369,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13691369 for (const auto & layer : layers) {
13701370 const uint32_t il = layer.il ;
13711371
1372- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
1372+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
13731373
13741374 // Write key type
13751375 const int32_t k_type_i = (int32_t )layer.k ->type ;
@@ -1391,7 +1391,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
13911391 for (const auto & layer : layers) {
13921392 const uint32_t il = layer.il ;
13931393
1394- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1394+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
13951395
13961396 // Write value type
13971397 const int32_t v_type_i = (int32_t )layer.v ->type ;
@@ -1415,7 +1415,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
14151415 for (const auto & layer : layers) {
14161416 const uint32_t il = layer.il ;
14171417
1418- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1418+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
14191419
14201420 // Write value type
14211421 const int32_t v_type_i = (int32_t )layer.v ->type ;
@@ -1552,7 +1552,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
15521552 for (const auto & layer : layers) {
15531553 const uint32_t il = layer.il ;
15541554
1555- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
1555+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
15561556
15571557 // Read type of key
15581558 int32_t k_type_i_ref;
@@ -1582,7 +1582,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
15821582 for (const auto & layer : layers) {
15831583 const uint32_t il = layer.il ;
15841584
1585- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1585+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
15861586
15871587 // Read type of value
15881588 int32_t v_type_i_ref;
@@ -1612,7 +1612,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
16121612 for (const auto & layer : layers) {
16131613 const uint32_t il = layer.il ;
16141614
1615- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
1615+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
16161616
16171617 // Read type of value
16181618 int32_t v_type_i_ref;
@@ -1921,8 +1921,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
19211921 continue ;
19221922 }
19231923
1924- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (i) + hparams.n_embd_k_s ();
1925- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (i) + hparams.n_embd_v_s ();
1924+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (i) + hparams.n_embd_k_s (i );
1925+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (i) + hparams.n_embd_v_s (i );
19261926
19271927 const char * dev_name = " CPU" ;
19281928
@@ -2649,7 +2649,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
26492649 // Iterate and write all the keys first, each row is a cell
26502650 // Get whole range at a time
26512651 for (uint32_t il = 0 ; il < n_layer; ++il) {
2652- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
2652+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
26532653
26542654 // Write key type
26552655 const int32_t k_type_i = (int32_t )k_l[il]->type ;
@@ -2669,7 +2669,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
26692669
26702670 if (!v_trans) {
26712671 for (uint32_t il = 0 ; il < n_layer; ++il) {
2672- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2672+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
26732673
26742674 // Write value type
26752675 const int32_t v_type_i = (int32_t )v_l[il]->type ;
@@ -2690,7 +2690,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
26902690 // When v is transposed, we also need the element size and get the element ranges from each row
26912691 const uint32_t kv_size = size;
26922692 for (uint32_t il = 0 ; il < n_layer; ++il) {
2693- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2693+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
26942694
26952695 // Write value type
26962696 const int32_t v_type_i = (int32_t )v_l[il]->type ;
@@ -2837,7 +2837,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
28372837
28382838 // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
28392839 for (uint32_t il = 0 ; il < n_layer; ++il) {
2840- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s ();
2840+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa (il) + hparams.n_embd_k_s (il );
28412841
28422842 // Read type of key
28432843 int32_t k_type_i_ref;
@@ -2865,7 +2865,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
28652865
28662866 if (!v_trans) {
28672867 for (uint32_t il = 0 ; il < n_layer; ++il) {
2868- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2868+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
28692869
28702870 // Read type of value
28712871 int32_t v_type_i_ref;
@@ -2893,7 +2893,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
28932893 } else {
28942894 // For each layer, read the values for each cell (transposed)
28952895 for (uint32_t il = 0 ; il < n_layer; ++il) {
2896- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s ();
2896+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa (il) + hparams.n_embd_v_s (il );
28972897
28982898 // Read type of value
28992899 int32_t v_type_i_ref;
0 commit comments