@@ -11856,13 +11856,16 @@ struct llm_graph_context_mamba : public llm_graph_context {
1185611856 // (this is necessary in order to properly use the states before they are overwritten,
1185711857 // while avoiding unnecessary copies of the states)
1185811858 auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
11859+ ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
11860+
11861+ // Empty y that will be extended with each chunk of tokens
11862+ ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]);
1185911863
1186011864 if (n_seq_tokens == 1) {
1186111865 // if (true) {
1186211866 //DEBUG
1186311867 LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il);
1186411868 // If single-token, use ssm_scan op
11865- ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
1186611869 return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
1186711870 } else {
1186811871 //DEBUG
@@ -11930,10 +11933,10 @@ struct llm_graph_context_mamba : public llm_graph_context {
1193011933 dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i);
1193111934 ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1);
1193211935 ggml_tensor * dtA_tmp2 = ggml_permute(ctx, dtA_tmp1, 2, 0, 3, 1); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11933- ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11936+ ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2), 0 ); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
1193411937 cb(segsum, "segsum", il);
1193511938 /* !! */ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11936- decay = ggml_permute(ctx, decay, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs}
11939+ decay = ggml_cont(ctx, ggml_permute(ctx, decay, 1, 0, 2, 3) ); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs}
1193711940 cb(decay, "decay", il);
1193811941
1193911942 // step 5: compute surrogate_attention_matrix
@@ -11943,16 +11946,17 @@ struct llm_graph_context_mamba : public llm_graph_context {
1194311946
1194411947 // step 6: compute y
1194511948 ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3)
11946- /* !! */ ggml_tensor * y = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
11947- y = ggml_cont(ctx, ggml_permute(ctx, y , 0, 2, 1, 3));
11948- cb(y , "y ", il);
11949+ /* !! */ ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
11950+ y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk , 0, 2, 1, 3));
11951+ cb(y_chunk , "y_chunk ", il);
1194911952
1195011953 // step 7: compute dtxdecay
1195111954 ggml_tensor * decay_last = ggml_view_4d(ctx, decay,
1195211955 decay->ne[0], 1, decay->ne[2], decay->ne[3],
1195311956 decay->nb[1], decay->nb[2], decay->nb[3],
1195411957 (decay->ne[1] - 1) * decay->nb[1]);
11955- decay_last = ggml_permute(ctx, decay_last, 2, 0, 1, 3);
11958+ decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3));
11959+ cb(decay_last, "decay_last", il);
1195611960 B_perm = ggml_cont(ctx, B_perm);
1195711961 B_perm = ggml_repeat_4d(ctx, B_perm,
1195811962 B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]);
@@ -11964,22 +11968,42 @@ struct llm_graph_context_mamba : public llm_graph_context {
1196411968 /* !! */ ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
1196511969 cb(next_state, "next_state", il);
1196611970
11967- //DEBUG -- Single chunk w/out prev state
11968- ggml_tensor * out = ggml_concat(ctx,
11969- ggml_view_1d(ctx, y, ggml_nelements(y), 0),
11970- ggml_view_1d(ctx, next_state, ggml_nelements(next_state), 0),
11971- 0);
11972- return out;
11973-
11974- // // update previous state if present
11975- // if (true) {
11976- // // step 9: compute exp_dtA_cumsum
11977-
11978- // // step 10: compute y_prev
11979-
11980- // // step 11: update y from y_prev
11981- // }
11971+ // TODO: Skip y and state updates if no previous state
11972+ // FIXME!!! These chunk-recursion parts are not working yet
11973+
11974+ // update from previous state
11975+ ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1));
11976+ cb(exp_dtA_cumsum, "exp_dtA_cumsum", il);
11977+ ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum,
11978+ exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3],
11979+ exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3],
11980+ (exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]);
11981+ cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il);
11982+ next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum_last, 2, 0, 1, 3))));
11983+ cb(next_state, "next_state_updated", il);
11984+
11985+ // update from previous y
11986+ ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm);
11987+ cb(y_prev, "y_prev", il);
11988+ y_prev = ggml_mul(ctx, ggml_cont(ctx,
11989+ ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3))),
11990+ ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 0, 3)));
11991+ cb(y_prev, "y_prev_mul", il);
11992+ y_chunk = ggml_add(ctx, y_chunk, y_prev); //FIXME! Make sure the batch dim is in the right place
11993+ cb(y_chunk, "y_chunk_updated", il);
11994+
11995+ // recurse
11996+ y = ggml_concat(ctx, y, y_chunk, 2);
11997+ cb(y, "y", il);
11998+ ssm = next_state;
1198211999 }
12000+
12001+ // Concat the output y and state
12002+ ggml_tensor * out = ggml_concat(ctx,
12003+ ggml_view_1d(ctx, y, ggml_nelements(y), 0),
12004+ ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0),
12005+ 0);
12006+ return out;
1198312007 }
1198412008 };
1198512009
0 commit comments