Skip to content

Commit 3da5c97

Browse files
committed
fix(temp): Fix CBdecay to make decay contiguous for metal
We shouldn't need this once cumsum can operate on other dims and we can avoid all the various permutes elsewhere. Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <[email protected]>
1 parent ef12069 commit 3da5c97

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11938,7 +11938,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
1193811938
cb(decay, "decay", il);
1193911939

1194011940
// step 5: compute surrogate_attention_matrix
11941-
/* !! */ ggml_tensor * CBdecay = ggml_mul(ctx, CB, decay);
11941+
ggml_tensor * CBdecay = ggml_mul(ctx, CB, ggml_cont(ctx, decay));
1194211942
ggml_tensor * surrogate_attention_matrix = ggml_tri_keep(ctx, CBdecay, GGML_TRI_TYPE_LOWER_DIAG);
1194311943
cb(surrogate_attention_matrix, "surrogate_attention_matrix", il);
1194411944

0 commit comments

Comments
 (0)