Commit 12ecec0

Update llama-memory-recurrent.cpp
handle saving/loading null layers in recurrent memory
1 parent 0d92267 commit 12ecec0

1 file changed: +25 -0 lines changed

src/llama-memory-recurrent.cpp

Lines changed: 25 additions & 0 deletions
@@ -769,6 +769,11 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
 
+        if (r_l[il] == nullptr) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            continue;
+        }
+
         // Write key type
         const int32_t r_type_i = (int32_t)r_l[il]->type;
         io.write(&r_type_i, sizeof(r_type_i));
@@ -788,6 +793,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
 
+            // special key to handle null layers
+            if (s_l[il] == nullptr) {
+                // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+                continue;
+            }
+
             // Write value type
             const int32_t s_type_i = (int32_t)s_l[il]->type;
             io.write(&s_type_i, sizeof(s_type_i));
@@ -807,6 +818,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t mem_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // special key to handle null layers
+            if (s_l[il] == nullptr) {
+                // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+                continue;
+            }
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Write value type
@@ -951,6 +968,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers
+        if(r_l[il] == nullptr) continue;
 
         // Read type of key
         int32_t r_type_i_ref;
@@ -978,11 +997,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if(s_l[il] == nullptr) continue;
 
             // Read type of value
             int32_t s_type_i_ref;
             io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
             const int32_t s_type_i = (int32_t)s_l[il]->type;
+
             if (s_type_i != s_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                 return false;
@@ -1005,6 +1027,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     } else {
         // For each layer, read the values for each cell (transposed)
        for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if(s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Read type of value
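Note that the write path emits nothing for a skipped layer, not even a placeholder, so the read path stays aligned only because it checks its own r_l/s_l entries for null in the same layer order, as the in-code comments state. Below is a minimal, self-contained sketch of that invariant (not llama.cpp code): std::optional values and a std::stringstream stand in for the per-layer tensors and the llama_io streams.

// sketch: write per-layer state, emitting nothing for null layers,
// then read it back by skipping exactly the same layers
#include <cstdint>
#include <cstdio>
#include <optional>
#include <sstream>
#include <vector>

int main() {
    // layer 1 has no recurrent state (a "null layer")
    std::vector<std::optional<int32_t>> layers = {7, std::nullopt, 42};

    // write path: skip null layers, no marker is written for them
    std::stringstream io;
    for (size_t il = 0; il < layers.size(); ++il) {
        if (!layers[il]) continue; // mirrors: if (r_l[il] == nullptr) continue;
        io.write(reinterpret_cast<const char *>(&*layers[il]), sizeof(int32_t));
    }

    // read path: checking the same per-layer nulls keeps the reads aligned
    std::vector<std::optional<int32_t>> restored(layers.size());
    for (size_t il = 0; il < layers.size(); ++il) {
        if (!layers[il]) continue; // mirrors the writer's skip
        int32_t v = 0;
        io.read(reinterpret_cast<char *>(&v), sizeof(v));
        restored[il] = v;
    }

    for (size_t il = 0; il < restored.size(); ++il) {
        if (restored[il]) {
            std::printf("layer %zu: %d\n", il, (int) *restored[il]);
        } else {
            std::printf("layer %zu: null (skipped)\n", il);
        }
    }
    return 0;
}

The sketch also shows the implied requirement: writer and reader must agree on which layers are null, since the serialized stream itself carries no record of the skipped layers.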
