@@ -11879,13 +11879,13 @@ struct llm_graph_context_mamba : public llm_graph_context {
1187911879 // extract the state(s) for the sequences identified by ids
1188011880 if (ssm->ne[3] != ids->ne[0]) {
1188111881 ggml_tensor * ssm_perm = ggml_permute(ctx, ssm, 0, 2, 3, 1); // put the target dim in dim 1
11882- // ggml_tensor * ids_perm = ggml_permute(ctx, ids, 1, 2, 3, 0); // put the taget dim in dim 0
1188311882 ggml_tensor * ids_perm_rep = ggml_repeat_4d(ctx, ids,
1188411883 ids->ne[0], ssm->ne[1], ssm->ne[2], 1); // repeat to match expected shape
1188511884 ggml_tensor * ssm_ids = ggml_get_rows(ctx, ssm_perm, ids_perm_rep); // extract ids as rows
1188611885 ssm = ggml_cont(ctx, ggml_permute(ctx, ssm_ids, 0, 3, 1, 2)); // permute back to original shape
1188711886 GGML_ASSERT(ssm->ne[3] == ids->ne[0]);
1188811887 }
11888+ // ssm -> {d_state, head_dim, n_head, n_seqs}
1188911889
1189011890 // step 1: compute dt softplus
1189111891 // NOTE: In other implementations, the bias is added after
@@ -11988,7 +11988,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
1198811988 exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3],
1198911989 (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs}
1199011990 cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il);
11991- ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 1, 2 , 3, 0);
11991+ ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 2, 1 , 3, 0); // {1, 1, n_head, n_seqs}
1199211992 next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm)));
1199311993 cb(next_state, "next_state_updated", il);
1199411994
0 commit comments