@@ -11885,9 +11885,9 @@ struct llm_graph_context_mamba : public llm_graph_context {
1188511885 cb(dt_softplus, "dt_softplus", il);
1188611886
1188711887 // step 2: compute dtA and dtX
11888- /* !! */ ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs}
11888+ ggml_tensor * dtA = ggml_mul(ctx, dt_softplus, ggml_reshape_1d(ctx, A, A->ne[1])); // {n_head, n_seq_tokens, n_seqs}
1188911889 cb(dtA, "dtA", il);
11890- /* !! */ ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs}
11890+ ggml_tensor * dtX = ggml_mul(ctx, x, ggml_reshape_4d(ctx, dt_softplus, 1, dt_softplus->ne[0], dt_softplus->ne[1], dt_softplus->ne[2])); // {head_dim, n_head, n_seq_tokens, n_seqs}
1189111891 cb(dtX, "dtX", il);
1189211892
1189311893 // loop over all chunks
@@ -11924,19 +11924,18 @@ struct llm_graph_context_mamba : public llm_graph_context {
1192411924 // step 3: compute CB
1192511925 ggml_tensor * C_perm = ggml_permute(ctx, C_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs}
1192611926 ggml_tensor * B_perm = ggml_permute(ctx, B_chunk, 0, 2, 1, 3); // {d_state, n_seq_tokens, n_group, n_seqs}
11927- /* !! */ ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs}
11927+ ggml_tensor * CB = ggml_mul_mat(ctx, B_perm, C_perm); // {n_seq_tokens, n_seq_tokens, n_group, n_seqs}
1192811928 CB = ggml_repeat_4d(ctx, CB, CB->ne[0], CB->ne[1], CB->ne[2] * repeats, CB->ne[3]); // {n_seq_tokens, n_seq_tokens, n_head (repeats * n_group), n_seqs}
1192911929 cb(CB, "CB", il);
1193011930
1193111931 // step 4: compute decay
1193211932 ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk,
11933- dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i);
11934- ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1);
11935- ggml_tensor * dtA_tmp2 = ggml_permute(ctx, dtA_tmp1, 2, 0, 3, 1); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11936- ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2), 0); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11933+ dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
11934+ ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
11935+ ggml_tensor * segsum = ggml_cumsum(ctx, dtA_tmp1, 1); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1}
1193711936 cb(segsum, "segsum", il);
11938- /* !! */ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs }
11939- decay = ggml_cont(ctx, ggml_permute(ctx, decay, 1, 0, 2, 3)) ; // {n_seq_tokens_0, n_seq_tokens_1 , n_head, n_seqs}
11937+ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_head, chunk_size_i_0, n_seqs, chunk_size_i_1 }
11938+ decay = ggml_permute(ctx, decay, 2, 1, 3, 0) ; // {chunk_size_i_1, chunk_size_i_0 , n_head, n_seqs}
1194011939 cb(decay, "decay", il);
1194111940
1194211941 // step 5: compute surrogate_attention_matrix
@@ -11946,7 +11945,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
1194611945
1194711946 // step 6: compute y
1194811947 ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3)
11949- /* !! */ ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
11948+ ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
1195011949 y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3));
1195111950 cb(y_chunk, "y_chunk", il);
1195211951
@@ -11960,18 +11959,17 @@ struct llm_graph_context_mamba : public llm_graph_context {
1196011959 B_perm = ggml_cont(ctx, B_perm);
1196111960 B_perm = ggml_repeat_4d(ctx, B_perm,
1196211961 B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]);
11963- /* !! */ ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last);
11962+ ggml_tensor * dtxdecay = ggml_mul(ctx, dtX_chunk, decay_last);
1196411963 dtxdecay = ggml_cont(ctx, ggml_permute(ctx, dtxdecay, 1, 2, 0, 3));
1196511964 cb(dtxdecay, "dtxdecay", il);
1196611965
1196711966 // step 8: compute next_state
11968- /* !! */ ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
11967+ ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
1196911968 cb(next_state, "next_state", il);
1197011969
1197111970 // TODO: Skip y and state updates if no previous state
11972- // FIXME!!! These chunk-recursion parts are not working yet
1197311971
11974- // update from previous state
11972+ // step 9: update from previous state
1197511973 ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1));
1197611974 cb(exp_dtA_cumsum, "exp_dtA_cumsum", il);
1197711975 ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum,
@@ -11982,7 +11980,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
1198211980 next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum_last, 2, 0, 1, 3))));
1198311981 cb(next_state, "next_state_updated", il);
1198411982
11985- // update from previous y
11983+ // step 10: update from previous y
1198611984 ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm);
1198711985 cb(y_prev, "y_prev", il);
1198811986 y_prev = ggml_mul(ctx, ggml_cont(ctx,
@@ -11992,7 +11990,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
1199211990 y_chunk = ggml_add(ctx, y_chunk, y_prev); //FIXME! Make sure the batch dim is in the right place
1199311991 cb(y_chunk, "y_chunk_updated", il);
1199411992
11995- // recurse
11993+ // step 11: recurse
1199611994 y = ggml_concat(ctx, y, y_chunk, 2);
1199711995 cb(y, "y", il);
1199811996 ssm = next_state;
0 commit comments