
Commit e9660ac

feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent: 1edbb9f

2 files changed: 20 additions, 4 deletions


src/llama-hparams.cpp

Lines changed: 12 additions & 2 deletions
@@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
+    if (!recurrent_layer(il)) {
+        return 0;
+    }
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::recurrent_layer(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
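For context, the effect of the new gating can be seen with a small standalone sketch (not part of the commit, and not the actual llama.cpp sources): a simplified hparams struct that mirrors the Mamba branch of the two getters, applied to a hypothetical hybrid layout where only some layers are recurrent. Attention layers report a zero-sized recurrent state, so summing the per-layer sizes counts only the recurrent layers. The struct name, MAX_LAYERS constant, and layer pattern below are illustrative stand-ins.

    // Minimal sketch: a simplified hparams struct mirroring the per-layer
    // gating added in this commit. Field names follow the diff; everything
    // else (MAX_LAYERS, the layer pattern, main) is illustrative only.
    #include <array>
    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t MAX_LAYERS = 8; // stand-in for LLAMA_MAX_LAYERS

    struct hparams_sketch {
        uint32_t n_layer     = 4;
        uint32_t ssm_d_conv  = 4;
        uint32_t ssm_d_inner = 1536;
        uint32_t ssm_d_state = 16;
        std::array<bool, MAX_LAYERS> recurrent_layer_arr{}; // per-layer flag from the diff

        bool recurrent_layer(uint32_t il) const { return recurrent_layer_arr[il]; }

        // mirrors n_embd_k_s(il): conv state only exists on recurrent layers
        uint32_t n_embd_k_s(uint32_t il) const {
            if (!recurrent_layer(il)) { return 0; }
            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
        }

        // mirrors n_embd_v_s(il): ssm state only exists on recurrent layers
        uint32_t n_embd_v_s(uint32_t il) const {
            if (!recurrent_layer(il)) { return 0; }
            return ssm_d_state * ssm_d_inner;
        }
    };

    int main() {
        hparams_sketch hp;
        hp.recurrent_layer_arr = {true, false, true, false}; // hypothetical hybrid: alternating SSM/attention

        uint64_t total_k = 0, total_v = 0;
        for (uint32_t il = 0; il < hp.n_layer; ++il) {
            total_k += hp.n_embd_k_s(il);
            total_v += hp.n_embd_v_s(il);
        }
        // attention layers contribute 0, so only the two recurrent layers are counted
        std::printf("total conv-state embd: %llu, total ssm-state embd: %llu\n",
                    (unsigned long long) total_k, (unsigned long long) total_v);
        return 0;
    }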

src/llama-hparams.h

Lines changed: 8 additions & 2 deletions
@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -178,10 +181,13 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_k_s(uint32_t il = 0) const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_v_s(uint32_t il = 0) const;
+
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool recurrent_layer(uint32_t il) const;
 
     bool is_swa(uint32_t il) const;
 };
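On the header side, the new il parameter defaults to 0, so existing call sites that pass no layer index keep compiling; they behave as before only if the layers they query are flagged recurrent, and recurrent_layer_arr has no default initializer in this diff, so it has to be filled by the model loader. That loader-side logic is not part of this commit; the sketch below shows one hypothetical way the flags could be populated, where fill_recurrent_flags and the alternating layer pattern are made up purely for illustration.

    // Hypothetical sketch, not llama.cpp code: filling the per-layer
    // recurrence flags at load time. A pure recurrent model (e.g. Mamba or
    // RWKV) would mark every layer recurrent; a hybrid model marks only
    // selected layers. The alternating pattern is invented for illustration.
    #include <array>
    #include <cstdint>

    constexpr size_t MAX_LAYERS = 512; // stand-in for LLAMA_MAX_LAYERS

    static void fill_recurrent_flags(std::array<bool, MAX_LAYERS> & arr,
                                     uint32_t n_layer, bool hybrid) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            // hybrid: every other layer is recurrent; otherwise all layers are
            arr[il] = hybrid ? (il % 2 == 0) : true;
        }
    }

    int main() {
        std::array<bool, MAX_LAYERS> flags{};
        fill_recurrent_flags(flags, /*n_layer=*/6, /*hybrid=*/true);
        return flags[0] && !flags[1] ? 0 : 1; // layers 0, 2, 4 recurrent; 1, 3, 5 not
    }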
