Skip to content

Commit 62ac897

Browse files
committed
fix: Fix handling of batch size > 1 in chunk updates
Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent aba30d6 commit 62ac897

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/llama-model.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11981,24 +11981,25 @@ struct llm_graph_context_mamba : public llm_graph_context {
1198111981
// TODO: Skip y and state updates if no previous state
1198211982

1198311983
// step 9: update from previous state
11984-
ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1));
11984+
ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1)); // {n_head, chunk_size_i, n_seqs}
1198511985
cb(exp_dtA_cumsum, "exp_dtA_cumsum", il);
1198611986
ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum,
1198711987
exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3],
1198811988
exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3],
11989-
(exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]);
11989+
(exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]); // {n_head, 1, n_seqs}
1199011990
cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il);
11991-
next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum_last, 2, 0, 1, 3))));
11991+
ggml_tensor * exp_dtA_cumsum_perm = ggml_permute(ctx, exp_dtA_cumsum_last, 1, 2, 3, 0);
11992+
next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, exp_dtA_cumsum_perm)));
1199211993
cb(next_state, "next_state_updated", il);
1199311994

1199411995
// step 10: update from previous y
1199511996
ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm);
1199611997
cb(y_prev, "y_prev", il);
11997-
y_prev = ggml_mul(ctx, ggml_cont(ctx,
11998-
ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3))),
11999-
ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 0, 3)));
11998+
y_prev = ggml_mul(ctx,
11999+
ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3)),
12000+
ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 3, 0)));
1200012001
cb(y_prev, "y_prev_mul", il);
12001-
y_chunk = ggml_add(ctx, y_chunk, y_prev); //FIXME! Make sure the batch dim is in the right place
12002+
y_chunk = ggml_add(ctx, y_chunk, y_prev);
1200212003
cb(y_chunk, "y_chunk_updated", il);
1200312004

1200412005
// step 11: recurse

0 commit comments

Comments
 (0)