@@ -8027,13 +8027,8 @@ struct llm_build_plamo2 : public llm_graph_context {
             // ggml_graph_add_node(gf, model.layers[il].attn_norm);
             // cb(model.layers[il].attn_norm, "attn_norm", il);
 
-            ggml_graph_add_node(gf, model.layers[il].attn_norm);
-            cb(model.layers[il].attn_norm, "attn_norm_weight", il);
-
             // pre_mixer_norm
-            cb(inpL, "attn_pre_norm_input", il);
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
-            cb(cur, "attn_pre_norm", il);
 
             // check if this layer is Mamba or Attention
             bool is_mamba_layer = hparams.is_recurrent(il);
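
Note: `build_norm(..., LLM_NORM_RMS, ...)` applies an RMS normalization to the pre-mixer input. A minimal scalar sketch of that operation, for reference only (the function name and `eps` value here are illustrative, not taken from this patch):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i], applied per row.
std::vector<float> rms_norm(const std::vector<float> & x,
                            const std::vector<float> & w,
                            float eps = 1e-6f) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;                       // sum of squares
    const float scale = 1.0f / std::sqrt(ss / x.size() + eps);
    std::vector<float> y(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] * scale * w[i];                      // normalize, then scale by the norm weight
    }
    return y;
}
```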
@@ -8073,6 +8068,10 @@ struct llm_build_plamo2 : public llm_graph_context {
             cur = ggml_add(ctx0, cur, residual);
 
             inpL = cur;
+
+            if (il >= 2) {
+                break;
+            }
         }
 
         cur = inpL;
@@ -8238,17 +8237,28 @@ struct llm_build_plamo2 : public llm_graph_context {
         ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
         cb(zx, "mamba_in_proj", il);
 
+        zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
+        zx = ggml_reshape_4d(ctx0, zx, 2 * hparams.ssm_head_dim, hparams.ssm_num_heads, n_seq_tokens, n_seqs);
+        cb(zx, "mamba_in_proj_out", il);
+
         // split into z and x
         // => {d_inner, n_seq_tokens, n_seqs}
-        ggml_tensor * x = ggml_view_3d(ctx0, zx, d_inner, zx->ne[1], zx->ne[2], zx->nb[1], zx->nb[2], 0);
-        ggml_tensor * z = ggml_view_3d(ctx0, zx, d_inner, zx->ne[1], zx->ne[2], zx->nb[1], zx->nb[2], d_inner*ggml_element_size(zx));
+        ggml_tensor * x = ggml_view_4d(ctx0, zx, hparams.ssm_head_dim, zx->ne[1], zx->ne[2], zx->ne[3], zx->nb[1], zx->nb[2], zx->nb[3], hparams.ssm_head_dim*ggml_element_size(zx));
+        x = ggml_cont(ctx0, x);
+        x = ggml_reshape_4d(ctx0, x, hparams.ssm_head_dim * hparams.ssm_num_heads, 1, n_seq_tokens, n_seqs);
+        x = ggml_permute(ctx0, x, 0, 2, 1, 3);
         cb(x, "mamba_x_split", il);
+        ggml_tensor * z = ggml_view_4d(ctx0, zx, hparams.ssm_head_dim, zx->ne[1], zx->ne[2], zx->ne[3], zx->nb[1], zx->nb[2], zx->nb[3], 0);
+        z = ggml_cont(ctx0, z);
+        z = ggml_reshape_4d(ctx0, z, hparams.ssm_head_dim * hparams.ssm_num_heads, 1, n_seq_tokens, n_seqs);
+        z = ggml_permute(ctx0, z, 0, 2, 1, 3);
         cb(z, "mamba_z_split", il);
 
         // conv1d
         {
             // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
             ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
+            cb(conv_x, "mamba_conv1d_input", il);
 
             // copy last (d_conv - 1) columns back into the state cache
             ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
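
Note: the two `ggml_view_4d` calls added above read the in-projection output head by head, with `z` at byte offset 0 and `x` at offset `ssm_head_dim * element_size` within each `2 * ssm_head_dim` row. A plain-array sketch of that split, assuming the projection interleaves `[z_h | x_h]` per head as the offsets suggest (names and layout here are illustrative, not from the patch):

```cpp
#include <cstddef>
#include <vector>

// Split one token's in-projection output (2 * head_dim values per head) into z and x.
void split_zx_per_head(const std::vector<float> & zx,   // size = 2 * head_dim * num_heads
                       size_t head_dim, size_t num_heads,
                       std::vector<float> & z, std::vector<float> & x) {
    z.assign(head_dim * num_heads, 0.0f);
    x.assign(head_dim * num_heads, 0.0f);
    for (size_t h = 0; h < num_heads; ++h) {
        const float * row = zx.data() + h * 2 * head_dim;
        for (size_t i = 0; i < head_dim; ++i) {
            z[h * head_dim + i] = row[i];                // view at offset 0        -> z
            x[h * head_dim + i] = row[head_dim + i];     // view at offset head_dim -> x
        }
    }
}
```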
@@ -8264,9 +8274,6 @@ struct llm_build_plamo2 : public llm_graph_context {
             x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
             cb(x, "mamba_conv1d", il);
 
-            // bias
-            // x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); // PLaMo-2 does not use bias here
-
             x = ggml_silu(ctx0, x);
             cb(x, "mamba_conv1d_silu", il);
         }
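
Note: as I understand it, `ggml_ssm_conv` is used here for a causal depthwise convolution: the cached last `d_conv - 1` columns are prepended to the new tokens and each channel is convolved with its own `d_conv`-tap kernel. A scalar sketch of one channel under that assumption (indexing convention is mine, not taken from ggml):

```cpp
#include <cstddef>
#include <vector>

// One channel: convolve [state | x_new] with a d_conv-tap kernel, one output per new token.
std::vector<float> causal_conv_channel(const std::vector<float> & state,   // last d_conv - 1 inputs
                                       const std::vector<float> & x_new,   // n_tokens new inputs
                                       const std::vector<float> & kernel)  // d_conv taps
{
    std::vector<float> seq = state;
    seq.insert(seq.end(), x_new.begin(), x_new.end());   // like ggml_concat(conv, x) above
    std::vector<float> out(x_new.size(), 0.0f);
    for (size_t t = 0; t < x_new.size(); ++t) {
        for (size_t j = 0; j < kernel.size(); ++j) {
            out[t] += seq[t + j] * kernel[j];            // window ending at the current token
        }
    }
    return out;
}
```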
@@ -8279,9 +8286,9 @@ struct llm_build_plamo2 : public llm_graph_context {
 
         // split into dt, B, C
         const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
-        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim,  n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
-        ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*dt_dim);
-        ggml_tensor * B  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(dt_dim + d_state));
+        ggml_tensor * B  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
+        ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state);
+        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim,  n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state));
         cb(B, "mamba_B_raw", il);
         cb(C, "mamba_C_raw", il);
         cb(dt, "mamba_dt_raw", il);
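
Note: after this hunk the fused `x_bcdt` projection is read as `[B (d_state) | C (d_state) | dt (dt_dim)]` along the first dimension, so the element offsets are `0`, `d_state`, and `2 * d_state`. A plain-pointer sketch of the same split (struct and function names are placeholders):

```cpp
#include <cstddef>

struct bcdt_views {
    const float * B;
    const float * C;
    const float * dt;
};

// Split one row of the fused projection laid out as [B | C | dt].
bcdt_views split_bcdt(const float * x_bcdt, size_t d_state) {
    bcdt_views v;
    v.B  = x_bcdt;                 // offset 0
    v.C  = x_bcdt + d_state;       // offset d_state
    v.dt = x_bcdt + 2 * d_state;   // offset 2 * d_state
    return v;
}
```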
@@ -8296,15 +8303,17 @@ struct llm_build_plamo2 : public llm_graph_context {
 
         // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
         dt = build_lora_mm(model.layers[il].ssm_dt, dt);
-        dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
         cb(dt, "mamba_dt_proj", il);
 
         // This is corresponding to the broadcast_to operation in ssd_update_state() of the originall code
-        ggml_tensor * dt_expanded = ggml_new_tensor_2d(ctx0, dt->type, d_inner, n_seq_tokens);
+        ggml_tensor * dt_expanded = ggml_new_tensor_2d(ctx0, dt->type, dt_dim * hparams.ssm_num_heads, dt->ne[1]);
         dt = ggml_repeat(ctx0, dt, dt_expanded);
+        cb(dt, "mamba_dt_expanded", il);
+
         ggml_tensor * A_expanded = ggml_new_tensor_2d(ctx0, model.layers[il].ssm_a->type, d_state, d_inner);
         A_expanded = ggml_repeat(ctx0, model.layers[il].ssm_a, A_expanded);
-        cb(dt, "mamba_dt_expanded", il);
+        A_expanded = ggml_exp(ctx0, A_expanded);
+        A_expanded = ggml_scale(ctx0, A_expanded, -1.0f);
         cb(A_expanded, "mamba_A_expanded", il);
 
         // SSM scan operation
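
Note: the added `ggml_exp` followed by `ggml_scale(..., -1.0f)` treats the stored `ssm_a` tensor as `log A` and feeds `A = -exp(log A)` to the scan, keeping the state-decay coefficient negative, as in the usual Mamba parameterization. A scalar sketch of the broadcast-and-transform, with the repeat factor left as a placeholder since `ssm_a`'s stored shape is not shown in this hunk:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Broadcast log-A values and return A = -exp(log A) for each broadcast copy.
std::vector<float> expand_neg_exp_A(const std::vector<float> & a_log, size_t repeat) {
    std::vector<float> A;
    A.reserve(a_log.size() * repeat);
    for (size_t r = 0; r < repeat; ++r) {
        for (float v : a_log) {
            A.push_back(-std::exp(v));   // matches ggml_exp followed by ggml_scale(..., -1.0f)
        }
    }
    return A;
}
```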