@@ -769,6 +769,11 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
769769 // Get whole range at a time
770770 for (uint32_t il = 0 ; il < n_layer; ++il) {
771771
772+ if (r_l[il] == nullptr ) {
773+ // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
774+ continue ;
775+ }
776+
772777 // Write key type
773778 const int32_t r_type_i = (int32_t )r_l[il]->type ;
774779 io.write (&r_type_i, sizeof (r_type_i));
@@ -788,6 +793,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
788793 if (!s_trans) {
789794 for (uint32_t il = 0 ; il < n_layer; ++il) {
790795
796+ // special key to handle null layers
797+ if (s_l[il] == nullptr ) {
798+ // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
799+ continue ;
800+ }
801+
791802 // Write value type
792803 const int32_t s_type_i = (int32_t )s_l[il]->type ;
793804 io.write (&s_type_i, sizeof (s_type_i));
@@ -807,6 +818,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
807818 // When v is transposed, we also need the element size and get the element ranges from each row
808819 const uint32_t mem_size = size;
809820 for (uint32_t il = 0 ; il < n_layer; ++il) {
821+ // special key to handle null layers
822+ if (s_l[il] == nullptr ) {
823+ // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
824+ continue ;
825+ }
826+
810827 const uint32_t n_embd_s = hparams.n_embd_s ();
811828
812829 // Write value type
@@ -951,6 +968,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
951968
952969 // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
953970 for (uint32_t il = 0 ; il < n_layer; ++il) {
971+ // skip null layers
972+ if (r_l[il] == nullptr ) continue ;
954973
955974 // Read type of key
956975 int32_t r_type_i_ref;
@@ -978,11 +997,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
978997
979998 if (!s_trans) {
980999 for (uint32_t il = 0 ; il < n_layer; ++il) {
1000+ // skip null layers
1001+ if (s_l[il] == nullptr ) continue ;
9811002
9821003 // Read type of value
9831004 int32_t s_type_i_ref;
9841005 io.read_to (&s_type_i_ref, sizeof (s_type_i_ref));
9851006 const int32_t s_type_i = (int32_t )s_l[il]->type ;
1007+
9861008 if (s_type_i != s_type_i_ref) {
9871009 LLAMA_LOG_ERROR (" %s: mismatched s type (%d != %d, layer %d)\n " , __func__, s_type_i, s_type_i_ref, il);
9881010 return false ;
@@ -1005,6 +1027,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
10051027 } else {
10061028 // For each layer, read the values for each cell (transposed)
10071029 for (uint32_t il = 0 ; il < n_layer; ++il) {
1030+ // skip null layers
1031+ if (s_l[il] == nullptr ) continue ;
1032+
10081033 const uint32_t n_embd_s = hparams.n_embd_s ();
10091034
10101035 // Read type of value
0 commit comments