Skip to content

Commit 36244fe

Browse files
committed
fix: Fix permutation for nemotron-h shape
Something is definitely still broken for nemotron-h which may be the g > 1 aspect of the model Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 62ac897 commit 36244fe

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/llama-model.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11879,13 +11879,13 @@ struct llm_graph_context_mamba : public llm_graph_context {
1187911879
// extract the state(s) for the sequences identified by ids
1188011880
if (ssm->ne[3] != ids->ne[0]) {
1188111881
ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1
11882-
// ggml_tensor * ids_perm = ggml_permute(ctx, ids, 1, 2, 3, 0); // put the taget dim in dim 0
1188311882
ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids,
1188411883
ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape
1188511884
ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows
1188611885
ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape
1188711886
GGML_ASSERT(ssm->ne[3] == ids->ne[0]);
1188811887
}
11888+
// ssm -> {d_state, head_dim, n_head, n_seqs}
1188911889

1189011890
// step 1: compute dt softplus
1189111891
// NOTE: In other implementations, the bias is added after
@@ -11988,7 +11988,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
1198811988
exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3],
1198911989
(exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs}
1199011990
cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il);
11991-
ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 1, 2, 3, 0);
11991+
ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1, 3, 0); // {1, 1, n_head, n_seqs}
1199211992
next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm)));
1199311993
cb(next_state, "next_state_updated", il);
1199411994

0 commit comments

Comments
 (0)