@@ -11981,24 +11981,25 @@ struct llm_graph_context_mamba : public llm_graph_context {
1198111981 // TODO: Skip y and state updates if no previous state
1198211982
1198311983 // step 9: update from previous state
11984- ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1));
11984+ ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1)); // {n_head, chunk_size_i, n_seqs}
1198511985 cb(exp_dtA_cumsum, "exp_dtA_cumsum", il);
1198611986 ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum,
1198711987 exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3],
1198811988 exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3],
11989- (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]);
11989+ (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs}
1199011990 cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il);
11991- next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum_last, 2, 0, 1, 3))));
11991+ ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 1, 2, 3, 0);
11992+ next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm)));
1199211993 cb(next_state, "next_state_updated", il);
1199311994
1199411995 // step 10: update from previous y
1199511996 ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm);
1199611997 cb(y_prev, "y_prev", il);
11997- y_prev = ggml_mul(ctx, ggml_cont(ctx,
11998- ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3))) ,
11999- ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 0, 3 )));
11998+ y_prev = ggml_mul(ctx,
11999+ ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)),
12000+ ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 3, 0 )));
1200012001 cb(y_prev, "y_prev_mul", il);
12001- y_chunk = ggml_add(ctx, y_chunk, y_prev); //FIXME! Make sure the batch dim is in the right place
12002+ y_chunk = ggml_add(ctx, y_chunk, y_prev);
1200212003 cb(y_chunk, "y_chunk_updated", il);
1200312004
1200412005 // step 11: recurse
0 commit comments