@@ -11926,12 +11926,11 @@ struct llm_graph_context_mamba : public llm_graph_context {
1192611926 cb(CB, "CB", il);
1192711927
1192811928 // step 4: compute decay
11929- ggml_tensor * dtA_tmp0 = ggml_cont(ctx, ggml_permute(ctx, dtA_chunk, 2, 1, 3, 0)); // {1, n_seq_tokens n_head, n_seqs}
11930- ggml_tensor * dtA_tmp1 = ggml_repeat_4d(ctx, dtA_tmp0,
11931- dtA_tmp0->ne[0] * chunk_size_i, dtA_tmp0->ne[1], dtA_tmp0->ne[2], dtA_tmp0->ne[3]); // {n_seq_tokens, n_seq_tokens n_head, n_seqs}
11932- ggml_tensor * dtA_tmp2 = ggml_tri_keep(ctx, dtA_tmp1, GGML_TRI_TYPE_LOWER); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs}
11933- ggml_tensor * dtA_tmp3 = ggml_permute(ctx, dtA_tmp2, 1, 0, 2, 3); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11934- /* !! */ ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp3)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11929+ ggml_tensor * dtA_tmp0 = ggml_repeat_4d(ctx, dtA_chunk,
11930+ dtA_chunk->ne[0], dtA_chunk->ne[1], dtA_chunk->ne[2], dtA_chunk->ne[3] * chunk_size_i);
11931+ ggml_tensor * dtA_tmp1 = ggml_tri_dims(ctx, dtA_tmp0, nan(""), GGML_TRI_TYPE_LOWER, 3, 1);
11932+ ggml_tensor * dtA_tmp2 = ggml_permute(ctx, dtA_tmp1, 2, 0, 3, 1); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
11933+ ggml_tensor * segsum = ggml_cumsum(ctx, ggml_cont(ctx, dtA_tmp2)); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
1193511934 cb(segsum, "segsum", il);
1193611935 /* !! */ ggml_tensor * decay = ggml_exp(ctx, segsum); // {n_seq_tokens_1, n_seq_tokens_0, n_head, n_seqs}
1193711936 decay = ggml_permute(ctx, decay, 1, 0, 2, 3); // {n_seq_tokens_0, n_seq_tokens_1, n_head, n_seqs}
0 commit comments