
Commit bb87dbf

fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers

Branch: HybridRecurrentCache
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent c831e76 · commit bb87dbf

File tree

1 file changed (+2 −2 lines)

src/llama-model.cpp

Lines changed: 2 additions & 2 deletions
@@ -8939,11 +8939,11 @@ struct llm_build_mamba : public llm_graph_context {
         // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_recurrent_state(
                 gf, conv_states_all, state_copy,
-                hparams.n_embd_k_s(), n_seqs);
+                hparams.n_embd_k_s(il), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
         ggml_tensor * ssm = build_recurrent_state(
                 gf, ssm_states_all, state_copy,
-                hparams.n_embd_v_s(), n_seqs);
+                hparams.n_embd_v_s(il), n_seqs);
         ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
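
For context, the point of threading the layer index il through these calls is that a hybrid model stores recurrent state only on its mamba layers, so the state-size accessors need to answer per layer rather than globally. Below is a minimal C++ sketch of that idea; the struct, field, and method names are illustrative assumptions modelled loosely on llama.cpp's hparams, not the actual implementation in this commit.

// Illustrative sketch only (not the actual llama.cpp code): per-layer
// recurrent-state sizes, so a hybrid model reports 0 for attention layers
// and the mamba conv/ssm sizes for recurrent layers.
#include <cstdint>
#include <vector>

struct hparams_sketch {
    uint32_t ssm_d_conv  = 0;   // convolution kernel width
    uint32_t ssm_d_inner = 0;   // inner (channel) dimension of the SSM block
    uint32_t ssm_d_state = 0;   // per-channel SSM state dimension
    std::vector<bool> recurrent_layer; // true for mamba layers, false otherwise

    // rows of convolution state stored per sequence for layer il
    uint32_t n_embd_k_s(uint32_t il) const {
        if (!recurrent_layer[il]) {
            return 0;
        }
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }

    // rows of SSM state stored per sequence for layer il
    uint32_t n_embd_v_s(uint32_t il) const {
        if (!recurrent_layer[il]) {
            return 0;
        }
        return ssm_d_state * ssm_d_inner;
    }
};

With per-layer accessors of this shape, the cache can size the conv and ssm buffers for each layer independently, which lines up with the d_conv - 1 and d_state reshape dimensions visible in the diff above.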
