Skip to content

Commit 426a97c

Browse files
committed
feat: Keep ssm in f16 until output on SSD code path
Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent de43d0b commit 426a97c

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/models/graph-context-mamba.cpp

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,6 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
242242
// while avoiding to make unnecessary copies of the states)
243243
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
244244
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
245-
ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32);
246245

247246
// Empty y that will be extended with each chunk of tokens
248247
ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]);
@@ -252,6 +251,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
252251
//DEBUG
253252
LLAMA_LOG_DEBUG("build_mamba2_layer(layer %d): single-token update\n", il);
254253
// If single-token, use ssm_scan op
254+
ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32);
255255
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
256256
} else {
257257
//DEBUG
@@ -362,6 +362,9 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
362362

363363
// step 8: compute next_state
364364
ggml_tensor * next_state = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_permute(ctx, B_perm, 1, 0, 2, 3)), dtxdecay);
365+
if (next_state->type != ssm->type) {
366+
next_state = ggml_cast(ctx, next_state, ssm->type);
367+
}
365368
cb(next_state, "next_state", il);
366369

367370
// TODO: Skip y and state updates if no previous state
@@ -395,6 +398,9 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
395398
}
396399

397400
// Concat the output y and state
401+
if (ssm->type != y->type) {
402+
ssm = ggml_cast(ctx, ssm, y->type);
403+
}
398404
ggml_tensor * out = ggml_concat(ctx,
399405
ggml_view_1d(ctx, y, ggml_nelements(y), 0),
400406
ggml_view_1d(ctx, ssm, ggml_nelements(ssm), 0),

0 commit comments

Comments
 (0)