@@ -69,9 +69,6 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
         const char * dev_name = "CPU";

         ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
@@ -90,8 +87,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }

-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l[i] = k;
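
Note on the constructor hunks above: the recurrent cache keeps one recurrent state per cell and per layer rather than attention K/V rows, so the tensors are now sized by `hparams.n_embd_k_s()` / `hparams.n_embd_v_s()` (which take no layer index) and the attention-derived `n_embd_k_gqa(i)` / `n_embd_v_gqa(i)` terms drop out. A minimal sketch of what those accessors compute for a Mamba-style model, assuming the usual `ssm_*` hyperparameters (the real definitions live in `llama-hparams.cpp` and also handle RWKV-style models):

```cpp
// Sketch only, not the verbatim llama.cpp implementation.
#include <cstdint>

// K row per cell: the rolled convolution state (d_conv - 1 previous columns).
uint32_t n_embd_k_s_sketch(uint32_t ssm_d_conv, uint32_t ssm_d_inner) {
    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
}

// V row per cell: the SSM recurrent state (d_state values per inner channel).
uint32_t n_embd_v_s_sketch(uint32_t ssm_d_state, uint32_t ssm_d_inner) {
    return ssm_d_state * ssm_d_inner;
}
```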
@@ -754,14 +751,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
         io.write(&k_type_i, sizeof(k_type_i));

         // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         io.write(&k_size_row, sizeof(k_size_row));

         // Read each range of cells of k_size length each into tmp_buf and write out
@@ -774,14 +770,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
             io.write(&v_type_i, sizeof(v_type_i));

             // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             io.write(&v_size_row, sizeof(v_size_row));

             // Read each range of cells of v_size length each into tmp_buf and write out
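
For context on the two row-size writes above: `ggml_row_size(type, ne)` is the byte size of `ne` elements of `type` (`ggml_type_size(type) * ne / ggml_blck_size(type)`), so the value saved to the session file is simply the per-cell state size in bytes. A small hedged check, assuming an unquantized F32 state row:

```cpp
// Illustration only: relation between the row size written to the session
// file and the element count, assuming an F32 (block size 1) state row.
#include "ggml.h"
#include <cassert>

void check_state_row_size(uint32_t n_embd_k_s) {
    const size_t k_size_row = ggml_row_size(GGML_TYPE_F32, n_embd_k_s);
    assert(k_size_row == sizeof(float) * n_embd_k_s);
}
```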
@@ -795,7 +790,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -806,10 +801,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
             io.write(&v_size_el, sizeof(v_size_el));

             // Write GQA embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+            io.write(&n_embd_v_s, sizeof(n_embd_v_s));

             // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+            for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                 // Read each range of cells of v_size_el length each into tmp_buf and write out
                 for (const auto & range : cell_ranges) {
                     const size_t range_size = range.second - range.first;
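
In the transposed layout handled here, row `j` of `v_l[il]` holds element `j` of every cell, with all `kv_size` cells contiguous along that row; the loop therefore walks element rows and, within each row, contiguous cell ranges. A sketch of the offset arithmetic under that assumption (names mirror the surrounding code):

```cpp
// Sketch only: byte offset of element row j, cell c, in the transposed V cache.
// kv_size and v_size_el correspond to the variables in the loop above.
#include <cstddef>
#include <cstdint>

static size_t v_trans_src_offset(uint32_t j, uint32_t c, uint32_t kv_size, size_t v_size_el) {
    return ((size_t) j * kv_size + c) * v_size_el;
}
```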
@@ -942,7 +937,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();

         // Read type of key
         int32_t k_type_i_ref;
@@ -956,7 +950,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
         // Read row size of key
         uint64_t k_size_row_ref;
         io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s());
         if (k_size_row != k_size_row_ref) {
             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
             return false;
@@ -970,7 +964,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

             // Read type of value
             int32_t v_type_i_ref;
@@ -984,7 +977,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
             // Read row size of value
             uint64_t v_size_row_ref;
             io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s());
             if (v_size_row != v_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                 return false;
@@ -998,7 +991,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_s = hparams.n_embd_v_s();

             // Read type of value
             int32_t v_type_i_ref;
@@ -1018,17 +1011,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
                 return false;
             }

-            // Read GQA embedding size
-            uint32_t n_embd_v_gqa_ref;
-            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+            // Read state embedding size
+            uint32_t n_embd_v_s_ref;
+            io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref));
+            if (n_embd_v_s != n_embd_v_s_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il);
                 return false;
             }

             if (cell_count) {
                 // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                for (uint32_t j = 0; j < n_embd_v_s; ++j) {
                     const size_t dst_offset = (head + j * size) * v_size_el;
                     ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
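
The read side mirrors the transposed write layout: restored cells occupy slots `head .. head + cell_count - 1`, so for element row `j` the destination starts at `(head + j * size) * v_size_el` and `cell_count * v_size_el` bytes are copied, as in the loop above. A hedged sketch of that destination offset, where `size` is the total number of cells in the cache:

```cpp
// Sketch only: destination byte offset for element row j when restoring
// cell_count cells starting at slot `head` into a cache holding `size` cells.
#include <cstddef>
#include <cstdint>

static size_t v_trans_dst_offset(uint32_t head, uint32_t j, uint32_t size, size_t v_size_el) {
    return ((size_t) head + (size_t) j * size) * v_size_el;
}
```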