Commit 86788a2

temp: Cast ssm to F32

This will be needed until F16 support is added for SSM_SCAN

Branch: Mamba2SSD
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent 204cd80 commit 86788a2

File tree: 1 file changed (+1 −0)


src/models/graph-context-mamba.cpp (1 addition, 0 deletions)

@@ -242,6 +242,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
     // while avoiding to make unnecessary copies of the states)
     auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
         ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
+        ssm = ggml_cast(ctx, ssm, GGML_TYPE_F32);

         // Empty y that will be extended with each chunk of tokens
         ggml_tensor * y = ggml_new_tensor_4d(ctx, x->type, x->ne[0], x->ne[1], 0, x->ne[3]);
