Skip to content

Commit 3336f3c

Browse files
committed
fix: Use ggml_tri_dims to avoid perm/cont for initial decay step
Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 3da5c97 commit 3336f3c

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

src/llama-model.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11926,12 +11926,11 @@ struct llm_graph_context_mamba : public llm_graph_context {
1192611926
cb(CB, "CB", il);
1192711927

1192811928
// step 4: compute decay
11929-
ggml_tensor * dtA_tmp0 = ggml_cont(ctx, ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0)); // {1, n_seq_tokens n_head, n_seqs}
11930-
ggml_tensor * dtA_tmp1 = ggml_repeat_4d(ctx, dtA_tmp0,
11931-
dtA_tmp0->ne[0] * chunk_size_i, dtA_tmp0->ne[1], dtA_tmp0->ne[2], dtA_tmp0->ne[3]); // {n_seq_tokens, n_seq_tokens n_head, n_seqs}
11932-
ggml_tensor * dtA_tmp2 = ggml_tri_keep(ctx, dtA_tmp1, GGML_TRI_TYPE_LOWER); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs}
11933-
ggml_tensor * dtA_tmp3 = ggml_permute(ctx, dtA_tmp2, 1, 0, 2, 3); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11934-
/* !! */ ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp3)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11929+
ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk,
11930+
dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i);
11931+
ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1);
11932+
ggml_tensor * dtA_tmp2 = ggml_permute(ctx, dtA_tmp1, 2, 0, 3, 1); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11933+
ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
1193511934
cb(segsum, "segsum", il);
1193611935
/* !! */ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
1193711936
decay = ggml_permute(ctx, decay, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs}

0 commit comments

Comments
 (0)