@@ -8234,13 +8234,8 @@ struct llm_build_plamo2 : public llm_graph_context {
         // ggml_graph_add_node(gf, model.layers[il].attn_norm);
         // cb(model.layers[il].attn_norm, "attn_norm", il);
 
-        ggml_graph_add_node(gf, model.layers[il].attn_norm);
-        cb(model.layers[il].attn_norm, "attn_norm_weight", il);
-
         // pre_mixer_norm
-        cb(inpL, "attn_pre_norm_input", il);
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
-        cb(cur, "attn_pre_norm", il);
 
         // check if this layer is Mamba or Attention
         bool is_mamba_layer = hparams.is_recurrent(il);
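For reference, build_norm with LLM_NORM_RMS applies the standard RMSNorm, with model.layers[il].attn_norm as the multiplicative weight and no bias (the NULL argument); the removed cb() calls were only naming/debug callbacks around it:

    \mathrm{RMSNorm}(x)_i = w_i \cdot \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}}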
@@ -8280,6 +8275,10 @@ struct llm_build_plamo2 : public llm_graph_context {
         cur = ggml_add(ctx0, cur, residual);
 
         inpL = cur;
+
+        if (il >= 2) {
+            break;
+        }
     }
 
     cur = inpL;
@@ -8445,17 +8444,28 @@ struct llm_build_plamo2 : public llm_graph_context {
         ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
         cb(zx, "mamba_in_proj", il);
 
+        zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
+        zx = ggml_reshape_4d(ctx0, zx, 2 * hparams.ssm_head_dim, hparams.ssm_num_heads, n_seq_tokens, n_seqs);
+        cb(zx, "mamba_in_proj_out", il);
+
         // split into z and x
         // => {d_inner, n_seq_tokens, n_seqs}
-        ggml_tensor * x = ggml_view_3d(ctx0, zx, d_inner, zx->ne[1], zx->ne[2], zx->nb[1], zx->nb[2], 0);
-        ggml_tensor * z = ggml_view_3d(ctx0, zx, d_inner, zx->ne[1], zx->ne[2], zx->nb[1], zx->nb[2], d_inner*ggml_element_size(zx));
+        ggml_tensor * x = ggml_view_4d(ctx0, zx, hparams.ssm_head_dim, zx->ne[1], zx->ne[2], zx->ne[3], zx->nb[1], zx->nb[2], zx->nb[3], hparams.ssm_head_dim*ggml_element_size(zx));
+        x = ggml_cont(ctx0, x);
+        x = ggml_reshape_4d(ctx0, x, hparams.ssm_head_dim * hparams.ssm_num_heads, 1, n_seq_tokens, n_seqs);
+        x = ggml_permute(ctx0, x, 0, 2, 1, 3);
         cb(x, "mamba_x_split", il);
+        ggml_tensor * z = ggml_view_4d(ctx0, zx, hparams.ssm_head_dim, zx->ne[1], zx->ne[2], zx->ne[3], zx->nb[1], zx->nb[2], zx->nb[3], 0);
+        z = ggml_cont(ctx0, z);
+        z = ggml_reshape_4d(ctx0, z, hparams.ssm_head_dim * hparams.ssm_num_heads, 1, n_seq_tokens, n_seqs);
+        z = ggml_permute(ctx0, z, 0, 2, 1, 3);
         cb(z, "mamba_z_split", il);
 
         // conv1d
         {
             // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
             ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
+            cb(conv_x, "mamba_conv1d_input", il);
 
             // copy last (d_conv - 1) columns back into the state cache
             ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
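The view/cont/reshape/permute chains above appear to assume that the ssm_in projection emits z and x interleaved per head: each head contributes a block of 2 * ssm_head_dim values, with z at offset 0 and x at offset ssm_head_dim (the ggml_cont calls materialize the strided views, since ggml_reshape needs contiguous data). A minimal standalone sketch of that split on a flat buffer; head_dim, num_heads and the values are illustrative stand-ins, not taken from the patch:

    #include <cstdio>
    #include <vector>

    int main() {
        // stand-ins for hparams.ssm_head_dim / hparams.ssm_num_heads
        const int head_dim  = 4;
        const int num_heads = 3;

        // one token's in_proj output: num_heads blocks of [z_h | x_h]
        std::vector<float> zx(2 * head_dim * num_heads);
        for (int i = 0; i < (int) zx.size(); ++i) zx[i] = (float) i;

        std::vector<float> z, x;
        for (int h = 0; h < num_heads; ++h) {
            const float * blk = zx.data() + h * 2 * head_dim;
            z.insert(z.end(), blk,            blk + head_dim);     // like the view at offset 0
            x.insert(x.end(), blk + head_dim, blk + 2 * head_dim); // like the view at offset ssm_head_dim
        }

        for (int h = 0; h < num_heads; ++h) {
            printf("head %d: z starts at %g, x starts at %g\n", h, z[h * head_dim], x[h * head_dim]);
        }
        return 0;
    }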
@@ -8471,9 +8481,6 @@ struct llm_build_plamo2 : public llm_graph_context {
             x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
             cb(x, "mamba_conv1d", il);
 
-            // bias
-            // x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b); // PLaMo-2 does not use bias here
-
             x = ggml_silu(ctx0, x);
             cb(x, "mamba_conv1d_silu", il);
         }
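For orientation, a hedged sketch of what the ggml_ssm_conv step computes for a single channel, as used here: a causal depthwise 1-D convolution over a d_conv-wide window, with the (d_conv - 1) cached columns prepended so every token sees a full window, and no bias added (consistent with the removed comment above). Sizes and values are illustrative only:

    #include <cstdio>
    #include <vector>

    int main() {
        const int d_conv       = 4;
        const int n_seq_tokens = 3;

        std::vector<float> conv_x(d_conv - 1 + n_seq_tokens, 1.0f); // cached state + new tokens, one channel
        std::vector<float> w(d_conv, 0.5f);                         // that channel's conv kernel
        std::vector<float> y(n_seq_tokens, 0.0f);

        for (int t = 0; t < n_seq_tokens; ++t) {
            for (int k = 0; k < d_conv; ++k) {
                y[t] += conv_x[t + k] * w[k];                       // window ending at token t
            }
        }
        for (int t = 0; t < n_seq_tokens; ++t) {
            printf("y[%d] = %g\n", t, y[t]);
        }
        return 0;
    }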
@@ -8486,9 +8493,9 @@ struct llm_build_plamo2 : public llm_graph_context {
 
         // split into dt, B, C
         const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
-        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim , n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
-        ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*dt_dim );
-        ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state , n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(dt_dim + d_state));
+        ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state , n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
+        ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state );
+        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim , n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2* d_state));
         cb(B, "mamba_B_raw", il);
         cb(C, "mamba_C_raw", il);
         cb(dt, "mamba_dt_raw", il);
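The reordering of the three views changes the assumed row layout of x_bcdt from [dt | C | B] to [B | C | dt]; only the byte offsets passed to ggml_view_3d differ. A tiny standalone check of those offsets, with illustrative values for d_state and dt_dim (the real ones come from hparams and the dt_dim computed above):

    #include <cstdio>

    int main() {
        const long d_state = 64;   // illustrative
        const long dt_dim  = 128;  // illustrative
        const long elt     = 4;    // sizeof(float)

        // old layout: [dt | C | B]
        printf("old: dt@0  C@%ld  B@%ld (bytes)\n", dt_dim * elt, (dt_dim + d_state) * elt);
        // new layout: [B | C | dt], matching the reordered views above
        printf("new: B@0   C@%ld  dt@%ld (bytes)\n", d_state * elt, 2 * d_state * elt);
        return 0;
    }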
@@ -8503,15 +8510,17 @@ struct llm_build_plamo2 : public llm_graph_context {
 
         // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
         dt = build_lora_mm(model.layers[il].ssm_dt, dt);
-        dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
         cb(dt, "mamba_dt_proj", il);
 
         // This corresponds to the broadcast_to operation in ssd_update_state() of the original code
-        ggml_tensor * dt_expanded = ggml_new_tensor_2d(ctx0, dt->type, d_inner, n_seq_tokens );
+        ggml_tensor * dt_expanded = ggml_new_tensor_2d(ctx0, dt->type, dt_dim * hparams.ssm_num_heads, dt->ne[1] );
         dt = ggml_repeat(ctx0, dt, dt_expanded);
+        cb(dt, "mamba_dt_expanded", il);
+
         ggml_tensor * A_expanded = ggml_new_tensor_2d(ctx0, model.layers[il].ssm_a->type, d_state, d_inner);
         A_expanded = ggml_repeat(ctx0, model.layers[il].ssm_a, A_expanded);
-        cb(dt, "mamba_dt_expanded", il);
+        A_expanded = ggml_exp(ctx0, A_expanded);
+        A_expanded = ggml_scale(ctx0, A_expanded, -1.0f);
         cb(A_expanded, "mamba_A_expanded", il);
 
         // SSM scan operation
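A note on the new ggml_exp / ggml_scale pair: applying them to the broadcast ssm_a suggests that the tensor stores A_log here and A is recovered at graph-build time (an assumption on my part; some converters bake this into the weights instead). Under that assumption the two ops compute the usual selective-SSM parameterization, which the scan then discretizes with the per-token step:

    A = -\exp(A_{\log}), \qquad \bar{A} = \exp(\Delta \cdot A)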