@@ -69,9 +69,6 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
         const char * dev_name = "CPU";
 
         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
@@ -90,8 +87,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l[i] = k;
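Note on the sizing above: a recurrent cache stores per-layer recurrent state rather than attention K/V, so the tensors are now sized by hparams.n_embd_k_s() and hparams.n_embd_v_s() alone. As a rough sketch of what those helpers amount to for a Mamba-style layer (illustrative only; the authoritative definitions live in llama_hparams, and the parameter names below are assumptions, not code from this commit):

// Sketch, not the project's implementation: approximate per-cell state widths
// for a Mamba-style layer, assuming the usual ssm_* hyperparameters.
static uint32_t n_embd_k_s_sketch(uint32_t ssm_d_conv, uint32_t ssm_d_inner) {
    // rolling convolution window: (d_conv - 1) columns per inner channel
    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
}

static uint32_t n_embd_v_s_sketch(uint32_t ssm_d_state, uint32_t ssm_d_inner) {
    // SSM recurrent state: d_state values per inner channel
    return ssm_d_state * ssm_d_inner;
}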
@@ -756,14 +753,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
         io.write(&k_type_i, sizeof(k_type_i));
 
         // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         io.write(&k_size_row, sizeof(k_size_row));
 
         // Read each range of cells of k_size length each into tmp_buf and write out
@@ -776,14 +772,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
             io.write(&v_type_i, sizeof(v_type_i));
 
             // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             io.write(&v_size_row, sizeof(v_size_row));
 
             // Read each range of cells of v_size length each into tmp_buf and write out
@@ -797,7 +792,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -808,10 +803,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
             io.write(&v_size_el, sizeof(v_size_el));
 
             // Write GQA embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+            io.write(&n_embd_v_s, sizeof(n_embd_v_s));
 
             // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+            for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                 // Read each range of cells of v_size_el length each into tmp_buf and write out
                 for (const auto & range : cell_ranges) {
                     const size_t range_size = range.second - range.first;
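In this transposed branch, each per-layer V tensor is effectively an n_embd_v_s x kv_size matrix stored row by row, so the loop above writes, for every state dimension j, the contiguous slice covering the cells in each range. A small sketch of the offset arithmetic this relies on (the helper name is illustrative, not from the source):

// Sketch: byte offset of the cell range [first, second) inside row j of a
// transposed V tensor that holds kv_size cells of v_size_el bytes each.
static size_t v_row_range_offset(uint32_t j, uint32_t range_first, uint32_t kv_size, size_t v_size_el) {
    return ((size_t) j * kv_size + range_first) * v_size_el;
}
// The number of bytes written for that range is (range.second - range.first) * v_size_el.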
@@ -944,7 +939,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -958,7 +952,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
         // Read row size of key
         uint64_t k_size_row_ref;
         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         if (k_size_row != k_size_row_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
             return false;
@@ -972,7 +966,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -986,7 +979,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
             // Read row size of value
             uint64_t v_size_row_ref;
             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             if (v_size_row != v_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                 return false;
@@ -1000,7 +993,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -1020,17 +1013,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
                 return false;
             }
 
-            // Read GQA embedding size
-            uint32_t n_embd_v_gqa_ref;
-            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+            // Read state embedding size
+            uint32_t n_embd_v_s_ref;
+            io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref));
+            if (n_embd_v_s != n_embd_v_s_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il);
                 return false;
             }
 
             if (cell_count) {
                 // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                     const size_t dst_offset = (head + j * size) * v_size_el;
                     ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
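As a worked example of the per-cell row sizes the write and read paths now agree on (hyperparameters are assumed for illustration, not taken from the commit): with d_conv = 4, d_inner = 2048, d_state = 16 and F32 cache types, a key row is (4 - 1) * 2048 * 4 = 24576 bytes and a value row is 16 * 2048 * 4 = 131072 bytes. A standalone check of that arithmetic:

// Standalone arithmetic check of the example above (assumed hyperparameters).
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t d_conv = 4, d_inner = 2048, d_state = 16;
    const size_t   elem   = 4; // bytes per element for F32 cache types

    const size_t k_size_row = (size_t) (d_conv - 1) * d_inner * elem; // 24576
    const size_t v_size_row = (size_t) d_state * d_inner * elem;      // 131072

    std::printf("k row: %zu bytes, v row: %zu bytes\n", k_size_row, v_size_row);
    return 0;
}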