
Commit e15fa60

refactor: Remove layer index from n_embd_k/v_s
Now that it's not used at all in the unified cache, we don't need to use the layer index to zero it out for attention layers.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 4cd505f commit e15fa60
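For context, here is a minimal sketch of one of the simplified accessors as it reads after this change, reconstructed from the llama-hparams.cpp hunks below (the in-tree body may differ slightly); n_embd_v_s() follows the same pattern with the RWKV wkv_states and SSM state sizes:

uint32_t llama_hparams::n_embd_k_s() const {
    // RWKV models: token-shift state rows
    if (wkv_head_size != 0) {
        return token_shift_count * n_embd;
    }

    // otherwise: Mamba-style conv_states size
    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
}

Call sites simply add this layer-independent size to the per-layer GQA sizes, e.g. hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(), as the kv-cache hunks below show.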

File tree: 4 files changed (+14, -20 lines)


src/llama-hparams.cpp

Lines changed: 2 additions & 8 deletions
@@ -65,10 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
-    if (!recurrent_layer(il)) {
-        return 0;
-    }
+uint32_t llama_hparams::n_embd_k_s() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -79,10 +76,7 @@ uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
-    if (!recurrent_layer(il)) {
-        return 0;
-    }
+uint32_t llama_hparams::n_embd_v_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;

src/llama-hparams.h

Lines changed: 2 additions & 2 deletions
@@ -184,10 +184,10 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s(uint32_t il = 0) const;
+    uint32_t n_embd_k_s() const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s(uint32_t il = 0) const;
+    uint32_t n_embd_v_s() const;
 
     // whether or not the given layer is recurrent (for hybrid models)
     bool recurrent_layer(uint32_t il) const;

src/llama-kv-cache-recurrent.cpp

Lines changed: 8 additions & 8 deletions
@@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
         const char * dev_name = "CPU";
 
@@ -756,7 +756,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -776,7 +776,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -797,7 +797,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -944,7 +944,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -972,7 +972,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -1000,7 +1000,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
        // For each layer, read the values for each cell (transposed)
        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;

src/llama-model.cpp

Lines changed: 2 additions & 2 deletions
@@ -8939,11 +8939,11 @@ struct llm_build_mamba : public llm_graph_context {
         // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_recurrent_state(
                 gf, conv_states_all, state_copy,
-                hparams.n_embd_k_s(il), n_seqs);
+                hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
         ggml_tensor * ssm = build_recurrent_state(
                 gf, ssm_states_all, state_copy,
-                hparams.n_embd_v_s(il), n_seqs);
+                hparams.n_embd_v_s(), n_seqs);
         ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
