
Commit 52cd6d1 (parent: f864a1c)

fix: Use per-layer sizes in mamba layer builders

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>

File tree: 1 file changed (+6, -4 lines)


src/llama-model.cpp

Lines changed: 6 additions & 4 deletions
@@ -9024,9 +9024,9 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);

         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(), n_seqs);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(), n_seqs, true);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());

         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -9156,9 +9156,11 @@
         ggml_tensor * ssm_states_all = kv_state->get_v_l(il);

         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(), n_seqs);
+        LLAMA_LOG_DEBUG("%s[%d]: Building recurrent state conv\n", __func__, il);
+        ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
-        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(), n_seqs, true);
+        LLAMA_LOG_DEBUG("%s[%d]: Building recurrent state ssm\n", __func__, il);
+        ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true);
         ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());

         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
