@@ -9024,9 +9024,9 @@ struct llm_build_mamba : public llm_graph_context {
90249024 ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
90259025
90269026 // (ab)using the KV cache to store the states
9027- ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(), n_seqs);
9027+ ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il ), n_seqs);
90289028 conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
9029- ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(), n_seqs, true);
9029+ ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il ), n_seqs, true);
90309030 ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
90319031
90329032 // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -9156,9 +9156,11 @@ struct llm_build_mamba : public llm_graph_context {
91569156 ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
91579157
91589158 // (ab)using the KV cache to store the states
9159- ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(), n_seqs);
9159+ LLAMA_LOG_DEBUG("%s[%d]: Building recurrent state conv\n", __func__, il);
9160+ ggml_tensor * conv = self->build_recurrent_state(gf, conv_states_all, state_copy, self->hparams.n_embd_k_s(il), n_seqs);
91609161 conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
9161- ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(), n_seqs, true);
9162+ LLAMA_LOG_DEBUG("%s[%d]: Building recurrent state ssm\n", __func__, il);
9163+ ggml_tensor * ssm = self->build_recurrent_state(gf, ssm_states_all, state_copy, self->hparams.n_embd_v_s(il), n_seqs, true);
91629164 ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
91639165
91649166 // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
0 commit comments