Skip to content

Commit 3963a72

Browse files
committed
feat(wip): Partially working implementation with update from previous state
We will probably remove the chunking loop in favor of just using microbatching, but even in that case this logic will still be needed to carry state into subsequent microbatches. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 3b4055e commit 3963a72

File tree

1 file changed

+46
-22
lines changed

1 file changed

+46
-22
lines changed

src/llama-model.cpp

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11856,13 +11856,16 @@ struct llm_graph_context_mamba : public llm_graph_context {
1185611856
// (this is necessary in order to properly use the states before they are overwritten,
1185711857
// while avoiding to make unnecessary copies of the states)
1185811858
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
11859+
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
11860+
11861+
// Empty y that will be extended with each chunk of tokens
11862+
ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]);
1185911863

1186011864
if (n_seq_tokens == 1) {
1186111865
// if (true) {
1186211866
//DEBUG
1186311867
LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il);
1186411868
// If single-token, use ssm_scan op
11865-
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
1186611869
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
1186711870
} else {
1186811871
//DEBUG
@@ -11930,10 +11933,10 @@ struct llm_graph_context_mamba : public llm_graph_context {
1193011933
dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i);
1193111934
ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1);
1193211935
ggml_tensor * dtA_tmp2 = ggml_permute(ctx, dtA_tmp1, 2, 0, 3, 1); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11933-
ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11936+
ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2), 0); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
1193411937
cb(segsum, "segsum", il);
1193511938
/* !! */ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11936-
decay = ggml_permute(ctx, decay, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs}
11939+
decay = ggml_cont(ctx, ggml_permute(ctx, decay, 1, 0, 2, 3)); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs}
1193711940
cb(decay, "decay", il);
1193811941

1193911942
// step 5: compute surrogate_attention_matrix
@@ -11943,16 +11946,17 @@ struct llm_graph_context_mamba : public llm_graph_context {
1194311946

1194411947
// step 6: compute y
1194511948
ggml_tensor * dtX_chunk_perm = ggml_cont(ctx, ggml_permute(ctx, dtX_chunk, 1, 2, 0, 3)); //FIXME!!! This could just as easily be (2, 1, 0, 3)
11946-
/* !! */ ggml_tensor * y = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
11947-
y = ggml_cont(ctx, ggml_permute(ctx, y, 0, 2, 1, 3));
11948-
cb(y, "y", il);
11949+
/* !! */ ggml_tensor * y_chunk = ggml_mul_mat(ctx, dtX_chunk_perm, surrogate_attention_matrix);
11950+
y_chunk = ggml_cont(ctx, ggml_permute(ctx, y_chunk, 0, 2, 1, 3));
11951+
cb(y_chunk, "y_chunk", il);
1194911952

1195011953
// step 7: compute dtxdecay
1195111954
ggml_tensor * decay_last = ggml_view_4d(ctx, decay,
1195211955
decay->ne[0], 1, decay->ne[2], decay->ne[3],
1195311956
decay->nb[1], decay->nb[2], decay->nb[3],
1195411957
(decay->ne[1] - 1) * decay->nb[1]);
11955-
decay_last = ggml_permute(ctx, decay_last, 2, 0, 1, 3);
11958+
decay_last = ggml_cont(ctx, ggml_permute(ctx, decay_last, 2, 0, 1, 3));
11959+
cb(decay_last, "decay_last", il);
1195611960
B_perm = ggml_cont(ctx, B_perm);
1195711961
B_perm = ggml_repeat_4d(ctx, B_perm,
1195811962
B_perm->ne[0], B_perm->ne[1], B_perm->ne[2] * repeats, B_perm->ne[3]);
@@ -11964,22 +11968,42 @@ struct llm_graph_context_mamba : public llm_graph_context {
1196411968
/* !! */ ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
1196511969
cb(next_state, "next_state", il);
1196611970

11967-
//DEBUG -- Single chunk w/out prev state
11968-
ggml_tensor * out = ggml_concat(ctx,
11969-
ggml_view_1d(ctx, y, ggml_nelements(y), 0),
11970-
ggml_view_1d(ctx, next_state, ggml_nelements(next_state), 0),
11971-
0);
11972-
return out;
11973-
11974-
// // update previous state if present
11975-
// if (true) {
11976-
// // step 9: compute exp_dtA_cumsum
11977-
11978-
// // step 10: compute y_prev
11979-
11980-
// // step 11: update y from y_prev
11981-
// }
11971+
// TODO: Skip y and state updates if no previous state
11972+
// FIXME!!! These chunk-recursion parts are not working yet
11973+
11974+
// update from previous state
11975+
ggml_tensor * exp_dtA_cumsum = ggml_exp(ctx, ggml_cumsum(ctx, dtA_chunk, 1));
11976+
cb(exp_dtA_cumsum, "exp_dtA_cumsum", il);
11977+
ggml_tensor * exp_dtA_cumsum_last = ggml_view_4d(ctx, exp_dtA_cumsum,
11978+
exp_dtA_cumsum->ne[0], 1, exp_dtA_cumsum->ne[2], exp_dtA_cumsum->ne[3],
11979+
exp_dtA_cumsum->nb[1], exp_dtA_cumsum->nb[2], exp_dtA_cumsum->nb[3],
11980+
(exp_dtA_cumsum->ne[1] - 1) * exp_dtA_cumsum->nb[1]);
11981+
cb(exp_dtA_cumsum_last, "exp_dtA_cumsum_last", il);
11982+
next_state = ggml_add(ctx, next_state, ggml_mul(ctx, ssm, ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum_last, 2, 0, 1, 3))));
11983+
cb(next_state, "next_state_updated", il);
11984+
11985+
// update from previous y
11986+
ggml_tensor * y_prev = ggml_mul_mat(ctx, ggml_permute(ctx, C_chunk, 0, 2, 1, 3), ssm);
11987+
cb(y_prev, "y_prev", il);
11988+
y_prev = ggml_mul(ctx, ggml_cont(ctx,
11989+
ggml_cont(ctx, ggml_permute(ctx, y_prev, 2, 0, 1, 3))),
11990+
ggml_cont(ctx, ggml_permute(ctx, exp_dtA_cumsum, 1, 2, 0, 3)));
11991+
cb(y_prev, "y_prev_mul", il);
11992+
y_chunk = ggml_add(ctx, y_chunk, y_prev); //FIXME! Make sure the batch dim is in the right place
11993+
cb(y_chunk, "y_chunk_updated", il);
11994+
11995+
// recurse
11996+
y = ggml_concat(ctx, y, y_chunk, 2);
11997+
cb(y, "y", il);
11998+
ssm = next_state;
1198211999
}
12000+
12001+
// Concat the output y and state
12002+
ggml_tensor * out = ggml_concat(ctx,
12003+
ggml_view_1d(ctx, y, ggml_nelements(y), 0),
12004+
ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0),
12005+
0);
12006+
return out;
1198312007
}
1198412008
};
1198512009

0 commit comments

Comments (0)