@@ -8229,11 +8229,13 @@ struct llm_build_plamo2 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * residual = inpL;
 
-            ggml_graph_add_node(gf, model.layers[il].attn_norm);
-            cb(model.layers[il].attn_norm, "attn_norm", il);
+            // ggml_graph_add_node(gf, model.layers[il].attn_norm);
+            // cb(model.layers[il].attn_norm, "attn_norm", il);
 
             // pre_mixer_norm
-            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            // cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cur = ggml_rms_norm(ctx0, inpL, hparams.f_norm_rms_eps);
+            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
             cb(cur, "attn_pre_norm", il);
 
             // check if this layer is Mamba or Attention
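Note on the pre_mixer_norm change: for the RMS case, build_norm() and the explicit ggml_rms_norm + ggml_mul pair compute the same thing, namely RMS-normalizing the activations and then scaling them elementwise by the learned attn_norm weight; the explicit form just builds the two ops on ctx0 directly. A minimal sketch of that equivalence, assuming the stock ggml API (the helper name rms_norm_explicit is hypothetical, not part of llama.cpp):

#include "ggml.h"

// Hypothetical helper: what the two explicit ops above compute. Up to the cb()
// naming callbacks, this matches build_norm(x, w, NULL, LLM_NORM_RMS, il) for
// the RMS-norm case.
static ggml_tensor * rms_norm_explicit(ggml_context * ctx, ggml_tensor * x, ggml_tensor * w, float eps) {
    ggml_tensor * y = ggml_rms_norm(ctx, x, eps); // y = x / sqrt(mean(x^2) + eps), per row
    return ggml_mul(ctx, y, w);                   // elementwise scale by the learned weight
}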
@@ -8408,7 +8410,6 @@ struct llm_build_plamo2 : public llm_graph_context {
         const int64_t d_inner = hparams.ssm_d_inner;
         const int64_t d_state = hparams.ssm_d_state;
         const int64_t n_seqs = ubatch.n_seqs;
-        const float norm_rms_eps = hparams.f_norm_rms_eps;
 
         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
@@ -8481,17 +8482,17 @@ struct llm_build_plamo2 : public llm_graph_context {
 
         // split into dt, B, C
         const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
-        ggml_tensor * B  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
-        ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state);
-        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim,  n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state));
+        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim,  n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
+        ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*dt_dim);
+        ggml_tensor * B  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(dt_dim + d_state));
         cb(B, "mamba_B_raw", il);
         cb(C, "mamba_C_raw", il);
         cb(dt, "mamba_dt_raw", il);
 
         // Apply RMS norm to dt, B, C (PLaMo-2 specific)
-        B  = ggml_rms_norm(ctx0, B, norm_rms_eps);
-        C  = ggml_rms_norm(ctx0, C, norm_rms_eps);
-        dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
+        B  = build_norm(B,  model.layers[il].ssm_b_norm,  NULL, LLM_NORM_RMS, il);
+        C  = build_norm(C,  model.layers[il].ssm_c_norm,  NULL, LLM_NORM_RMS, il);
+        dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il);
         cb(B, "mamba_B_normed", il);
         cb(C, "mamba_C_normed", il);
         cb(dt, "mamba_dt_normed", il);
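Note on the split and norms above: x_bcdt is carved into three non-overlapping views along dim 0. With the new byte offsets, dt takes the first dt_dim rows, C the next d_state rows, and B the last d_state rows, and each view is then normalized with its own learned RMS-norm weight (ssm_dt_norm / ssm_c_norm / ssm_b_norm) via build_norm instead of the parameter-free ggml_rms_norm. A minimal sketch of the offset arithmetic, assuming the stock ggml API (the helper split_dt_c_b and the struct dt_c_b_views are hypothetical, not part of llama.cpp):

#include "ggml.h"

// Hypothetical illustration of the view offsets used above. x_bcdt is assumed to
// have shape [dt_dim + 2*d_state, n_seq_tokens, n_seqs]; each slice's byte offset
// is its starting row times the element size.
struct dt_c_b_views { ggml_tensor * dt; ggml_tensor * C; ggml_tensor * B; };

static dt_c_b_views split_dt_c_b(ggml_context * ctx, ggml_tensor * x_bcdt,
                                 int64_t dt_dim, int64_t d_state,
                                 int64_t n_seq_tokens, int64_t n_seqs) {
    const size_t es = ggml_element_size(x_bcdt);
    dt_c_b_views v;
    v.dt = ggml_view_3d(ctx, x_bcdt, dt_dim,  n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
    v.C  = ggml_view_3d(ctx, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], es*dt_dim);
    v.B  = ggml_view_3d(ctx, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], es*(dt_dim + d_state));
    return v;
}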