@@ -10536,7 +10536,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
         Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens);

-        Qcur = build_norm(Qcur, model.layers[il].wq, NULL, LLM_NORM_RMS, il);
+        ggml_tensor * wq = ggml_cast(ctx0, model.layers[il].wq, Qcur->type);
+        Qcur = build_norm(Qcur, wq, NULL, LLM_NORM_RMS, il);
         cb(Qcur, "Qcur_normed", il);

         Qcur = ggml_rope_ext(
@@ -10545,7 +10546,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
                 ext_factor, attn_factor, beta_fast, beta_slow
                 );

-        Kcur = build_norm(Kcur, model.layers[il].wk, NULL, LLM_NORM_RMS, il);
+        ggml_tensor * wk = ggml_cast(ctx0, model.layers[il].wk, Kcur->type);
+        Kcur = build_norm(Kcur, wk, NULL, LLM_NORM_RMS, il);
         cb(Kcur, "Kcur_normed", il);

         Kcur = ggml_rope_ext(
0 commit comments