
Commit 188ae84

refact: Avoid permute and cont for first cumsum

Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent 3963a72 commit 188ae84

1 file changed: +14 -16 lines changed
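In short, the refactor changes step 4 ("compute decay") so that the first cumsum runs directly over dim 1 of the repeated dtA tensor instead of permuting the summed axis to the front, which previously forced a ggml_cont copy before the cumsum and another after the exp; the remaining permute is deferred and left as a view. (The later comments are also renumbered as steps 9-11 and a stale FIXME is dropped.) Below is a minimal side-by-side sketch of the two graph fragments, assuming the branch-local ggml_cumsum(ctx, t, dim), ggml_tri_dims, ggml_exp, and GGML_TRI_TYPE_LOWER used in this file (not upstream ggml API), with dtA_rep standing in for the repeated dtA_chunk:

// Sketch only: old vs. new construction of the per-chunk decay tensor (step 4).
// Assumes the Mamba2SSD-branch operators named above; dtA_rep has layout
// {n_head, chunk_size_0, n_seqs, chunk_size_1} after ggml_repeat_4d.
#include <cmath>
#include "ggml.h"

static ggml_tensor * build_decay_old(ggml_context * ctx, ggml_tensor * dtA_rep) {
    ggml_tensor * masked = ggml_tri_dims(ctx, dtA_rep, nan(""), GGML_TRI_TYPE_LOWER, 3, 1);
    // move the summed axis to dim 0, which requires a contiguous copy before the cumsum
    ggml_tensor * perm   = ggml_permute(ctx, masked, 2, 0, 3, 1);
    ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, perm), 0);
    // ...followed by another permute and a second ggml_cont after the exp
    return ggml_cont(ctx, ggml_permute(ctx, ggml_exp(ctx, segsum), 1, 0, 2, 3));
}

static ggml_tensor * build_decay_new(ggml_context * ctx, ggml_tensor * dtA_rep) {
    ggml_tensor * masked = ggml_tri_dims(ctx, dtA_rep, nan(""), GGML_TRI_TYPE_LOWER, 3, 1);
    // cumsum directly over dim 1 of the existing layout: no permute, no ggml_cont
    ggml_tensor * segsum = ggml_cumsum(ctx, masked, 1);
    // the final permute is left as a view; no contiguous copy in the new path
    return ggml_permute(ctx, ggml_exp(ctx, segsum), 2, 1, 3, 0);
}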

src/llama-model.cpp

Lines changed: 14 additions & 16 deletions
@@ -11885,9 +11885,9 @@ struct llm_graph_context_mamba : public llm_graph_context {
  cb(dt_softplus, "dt_softplus", il);

  // step 2: compute dtA and dtX
- ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs}
+ ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs}
  cb(dtA, "dtA", il);
- ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs}
+ ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs}
  cb(dtX, "dtX", il);

  // loop over all chunks
@@ -11924,19 +11924,18 @@ struct llm_graph_context_mamba : public llm_graph_context {
  // step 3: compute CB
  ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs}
  ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs}
- ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs}
+ ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs}
  CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs}
  cb(CB, "CB", il);

  // step 4: compute decay
  ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk,
-     dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i);
- ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1);
- ggml_tensor * dtA_tmp2 = ggml_permute(ctx, dtA_tmp1, 2, 0, 3, 1); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
- ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2), 0); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
+     dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
+ ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
+ ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
  cb(segsum, "segsum", il);
- ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
- decay = ggml_cont(ctx, ggml_permute(ctx, decay, 1, 0, 2, 3)); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs}
+ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
+ decay = ggml_permute(ctx, decay, 2, 1, 3, 0); // {chunk_size_i_1, chunk_size_i_0, n_head, n_seqs}
  cb(decay, "decay", il);

  // step 5: compute surrogate_attention_matrix
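As a reading aid rather than part of the commit: under the usual SSD formulation, steps 3-5 build, per head, a lower-triangular decay matrix and a surrogate attention matrix of roughly the form

    L_{ij} = \exp\left( \sum_{k=j+1}^{i} \Delta t_k A \right) \quad (i \ge j), \qquad M_{ij} = (C_i^\top B_j) \, L_{ij}

so segsum holds the running sums of dtA within the chunk, decay is its elementwise exponential, and step 5 (outside this hunk) presumably combines it with CB.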
@@ -11946,7 +11945,7 @@ struct llm_graph_context_mamba : public llm_graph_context {

  // step 6: compute y
  ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3)
- ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
+ ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
  y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3));
  cb(y_chunk, "y_chunk", il);

@@ -11960,18 +11959,17 @@ struct llm_graph_context_mamba : public llm_graph_context {
  B_perm = ggml_cont(ctx, B_perm);
  B_perm = ggml_repeat_4d(ctx, B_perm,
      B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]);
- ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last);
+ ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last);
  dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3));
  cb(dtxdecay, "dtxdecay", il);

  // step 8: compute next_state
- ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
+ ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
  cb(next_state, "next_state", il);

  // TODO: Skip y and state updates if no previous state
- // FIXME!!! These chunk-recursion parts are not working yet

- // update from previous state
+ // step 9: update from previous state
  ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1));
  cb(exp_dtA_cumsum, "exp_dtA_cumsum", il);
  ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum,
@@ -11982,7 +11980,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
  next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum_last, 2, 0, 1, 3))));
  cb(next_state, "next_state_updated", il);

- // update from previous y
+ // step 10: update from previous y
  ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm);
  cb(y_prev, "y_prev", il);
  y_prev = ggml_mul(ctx, ggml_cont(ctx,
@@ -11992,7 +11990,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
  y_chunk = ggml_add(ctx, y_chunk, y_prev); //FIXME! Make sure the batch dim is in the right place
  cb(y_chunk, "y_chunk_updated", il);

- // recurse
+ // step 11: recurse
  y = ggml_concat(ctx, y, y_chunk, 2);
  cb(y, "y", il);
  ssm = next_state;
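Again as an interpretation rather than commit text: steps 9-11 carry the recurrence across chunks. Writing h_{c-1} for the state left by the previous chunk (ssm) and \alpha_t = \exp\left( \sum_{k \le t} \Delta t_k A \right) for the cumulative decay within the current chunk (exp_dtA_cumsum), the updates amount to roughly

    h_c = h_c^{\text{intra}} + \alpha_T \odot h_{c-1}, \qquad y_t \mathrel{+}= C_t^\top (\alpha_t \odot h_{c-1})

with \alpha_T the value at the chunk's last position (exp_dtA_cumsum_last); y_chunk is then concatenated onto y and ssm is set to next_state for the next chunk.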
