@@ -10466,20 +10466,14 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
1046610466
1046710467 if (il == n_layer - 1 && inp_out_ids) {
1046810468 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10469- inpL = ggml_get_rows(ctx0, inpL , inp_out_ids);
10469+ residual = ggml_get_rows(ctx0, residual , inp_out_ids);
1047010470 }
1047110471
1047210472 // residual connection
1047310473 cur = ggml_add(ctx0, cur, residual);
1047410474 cb(cur, "ffn_residual", il);
1047510475
1047610476 inpL = cur;
10477-
10478- if (il == 1) {
10479- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
10480- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
10481- break;
10482- }
1048310477 }
1048410478
1048510479 cur = inpL;
@@ -10627,7 +10621,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
1062710621 // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
1062810622 ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
1062910623 cb(zx, "mamba_in_proj", il);
10630-
10624+ // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
1063110625 zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
1063210626 zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
1063310627 cb(zx, "mamba_in_proj_out", il);
@@ -10636,14 +10630,11 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
1063610630 // => {head_dim * n_heads, n_seq_tokens, n_seqs}
1063710631 ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx));
1063810632 x = ggml_cont(ctx0, x);
10639- x = ggml_reshape_4d (ctx0, x, head_dim * n_heads, 1 , n_seq_tokens, n_seqs);
10640- x = ggml_permute(ctx0, x, 0, 2, 1, 3);
10633+ x = ggml_reshape_3d (ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
10634+ // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
1064110635 cb(x, "mamba_x_split", il);
1064210636
1064310637 ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0);
10644- z = ggml_cont(ctx0, z);
10645- z = ggml_reshape_4d(ctx0, z, head_dim * n_heads, 1, n_seq_tokens, n_seqs);
10646- z = ggml_permute(ctx0, z, 0, 2, 1, 3);
1064710638 cb(z, "mamba_z_split", il);
1064810639
1064910640 // conv1d
@@ -10699,11 +10690,10 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
1069910690 dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
1070010691 cb(dt, "mamba_dt_proj", il);
1070110692
10702- ggml_tensor * A = ggml_new_tensor_2d(ctx0, model.layers[il].ssm_a->type, d_state, n_heads);
10703- A = ggml_repeat(ctx0, model.layers[il].ssm_a, A);
10693+ ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
1070410694 cb(A, "mamba_A", il);
1070510695
10706- x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * x->nb[0], x->nb[1], x->nb[2] , 0);
10696+ x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x) , 0);
1070710697 B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
1070810698 C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);
1070910699
@@ -10725,22 +10715,22 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
1072510715 // store last states
1072610716 ggml_build_forward_expand(gf,
1072710717 ggml_cpy(ctx0,
10728- ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
10718+ ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3] ),
1072910719 ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs,
1073010720 kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
1073110721
10732- ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * x->nb[0] , head_dim * n_heads * x->nb[1] , head_dim * n_heads * n_seq_tokens * x->nb[2] , 0);
10722+ ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x) , head_dim * n_heads * ggml_element_size(x) , head_dim * n_heads * n_seq_tokens * ggml_element_size(x) , 0);
1073310723 cb(y, "mamba_y_view", il);
1073410724
1073510725 // Add D parameter and apply gating with z
1073610726 // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
10737- y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
10738- cb(y, "mamba_y_with_D", il);
10727+ ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
10728+ y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
10729+ cb(y, "mamba_y_add_d", il);
1073910730
1074010731 ggml_tensor * z_silu = ggml_silu(ctx0, ggml_cont(ctx0, z));
1074110732 cb(z_silu, "mamba_z_silu", il);
1074210733
10743- y = ggml_reshape_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs);
1074410734 y = ggml_mul(ctx0, y, z_silu);
1074510735 cb(y, "mamba_y_gated", il);
1074610736
0 commit comments