@@ -242,7 +242,6 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
242242 // while avoiding to make unnecessary copies of the states)
243243 auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
244244 ggml_tensor * ssm = ggml_reshape_4d (ctx, states, d_state, head_dim, n_head, mctx_cur->get_size ());
245- ssm = ggml_cast (ctx, ssm, GGML_TYPE_F32);
246245
247246 // Empty y that will be extended with each chunk of tokens
248247 ggml_tensor * y = ggml_new_tensor_4d (ctx, x->type , x->ne [0 ], x->ne [1 ], 0 , x->ne [3 ]);
@@ -252,6 +251,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
252251 // DEBUG
253252 LLAMA_LOG_DEBUG (" build_mamba2_layer(layer %d): single-token update\n " , il);
254253 // If single-token, use ssm_scan op
254+ ssm = ggml_cast (ctx, ssm, GGML_TYPE_F32);
255255 return ggml_ssm_scan (ctx, ssm, x, dt, A, B, C, ids);
256256 } else {
257257 // DEBUG
@@ -362,6 +362,9 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
362362
363363 // step 8: compute next_state
364364 ggml_tensor * next_state = ggml_mul_mat (ctx, ggml_cont (ctx, ggml_permute (ctx, B_perm, 1 , 0 , 2 , 3 )), dtxdecay);
365+ if (next_state->type != ssm->type ) {
366+ next_state = ggml_cast (ctx, next_state, ssm->type );
367+ }
365368 cb (next_state, " next_state" , il);
366369
367370 // TODO: Skip y and state updates if no previous state
@@ -395,6 +398,9 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
395398 }
396399
397400 // Concat the output y and state
401+ if (ssm->type != y->type ) {
402+ ssm = ggml_cast (ctx, ssm, y->type );
403+ }
398404 ggml_tensor * out = ggml_concat (ctx,
399405 ggml_view_1d (ctx, y, ggml_nelements (y), 0 ),
400406 ggml_view_1d (ctx, ssm, ggml_nelements (ssm), 0 ),
0 commit comments