@@ -1555,7 +1555,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

     // SSM parameters
-    ml.get_key(LLM_KV_MAMBA_D_SSM,     hparams.ssm_mamba_d_ssm);
     ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
     ml.get_key(LLM_KV_SSM_INNER_SIZE,  hparams.ssm_d_inner);
     ml.get_key(LLM_KV_SSM_STATE_SIZE,  hparams.ssm_d_state);
@@ -4514,7 +4513,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     const int64_t ssm_conv_kernel_size  = hparams.ssm_d_conv;  // ssm_conv_kernel_size
     const int64_t ssm_n_groups          = hparams.ssm_n_group; // ssm_n_groups
     const int64_t ssm_state_size        = hparams.ssm_d_state; // ssm_state_size
-    const int64_t ssm_mamba_d_ssm       = hparams.ssm_mamba_d_ssm;
+    const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand
     const int64_t ssm_num_heads         = hparams.ssm_dt_rank; // ssm_num_heads
-    const int64_t ssm_conv_dim          = ssm_mamba_d_ssm + 2 * ssm_n_groups * ssm_state_size;
-    const int64_t ssm_projection_size   = ssm_mamba_d_ssm + ssm_conv_dim + ssm_num_heads;
+    const int64_t ssm_conv_dim          = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
+    const int64_t ssm_projection_size   = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;
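For reference, the two derived widths above follow directly from the SSM hyperparameters. Below is a minimal standalone sketch of that arithmetic, using made-up Falcon-H1-style values rather than numbers read from any real GGUF:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical values; a real model reads these from the GGUF metadata.
    const int64_t ssm_intermediate_size = 4096; // hparams.ssm_d_inner
    const int64_t ssm_state_size        = 256;  // hparams.ssm_d_state
    const int64_t ssm_n_groups          = 1;    // hparams.ssm_n_group
    const int64_t ssm_num_heads         = 128;  // hparams.ssm_dt_rank

    // xBC feeds the convolution: x plus one B and one C block per group.
    const int64_t ssm_conv_dim        = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
    // in_proj emits z, xBC and dt as one fused tensor.
    const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;

    std::printf("conv dim       = %lld\n", (long long) ssm_conv_dim);        // 4608
    std::printf("projection dim = %lld\n", (long long) ssm_projection_size); // 8832
    return 0;
}
```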
@@ -14768,10 +14767,10 @@ struct llm_build_falcon_h1 : public llm_graph_context {
     const auto kv_head = kv_state->get_head();

     const int64_t d_conv   = hparams.ssm_d_conv;
-    const int64_t d_ssm    = hparams.ssm_mamba_d_ssm;
+    const int64_t d_inner  = hparams.ssm_d_inner;
     const int64_t d_state  = hparams.ssm_d_state;
     const int64_t n_head   = hparams.ssm_dt_rank;
-    const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_ssm / n_head : hparams.ssm_head_dim;
+    const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim;
     const int64_t n_group  = hparams.ssm_n_group;
     const int64_t n_seqs   = ubatch.n_seqs;
@@ -14785,7 +14784,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
     ggml_tensor * ssm_states_all = kv_state->get_s_l(il);

     ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
-    conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs);
+    conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);

     // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
     cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -14798,22 +14797,22 @@ struct llm_build_falcon_h1 : public llm_graph_context {

     // split the above in three
     ggml_tensor * z   = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
-    ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_ssm + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_ssm*ggml_element_size(zxBCdt));
-    ggml_tensor * dt  = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_ssm + 2*n_group*d_state)*ggml_element_size(zxBCdt));
+    ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
+    ggml_tensor * dt  = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));

     // conv
     {
         // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
         ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);

         // copy last (d_conv - 1) columns back into the state cache
-        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));

         ggml_build_forward_expand(gf,
             ggml_cpy(ctx0, last_conv,
                 ggml_view_1d(ctx0, conv_states_all,
-                    (d_conv - 1)*(d_ssm + 2*n_group*d_state)*(n_seqs),
-                    kv_head*(d_conv - 1)*(d_ssm + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+                    (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+                    kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));

         // 1D convolution
         // The equivalent is to make a self-overlapping view of conv_x
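The z/xBC/dt views above slice a single fused in_proj row by element offset, and the conv cache keeps the trailing (d_conv - 1) columns of xBC for each sequence. A small sketch of those offsets and sizes, reusing the same hypothetical dimensions as in the earlier example:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative dimensions only.
    const int64_t d_inner = 4096, d_state = 256, n_group = 1, n_head = 128, d_conv = 4;

    // Per-token layout of a fused zxBCdt row produced by in_proj:
    //   [ z : d_inner | xBC : d_inner + 2*n_group*d_state | dt : n_head ]
    const int64_t off_z   = 0;
    const int64_t off_xBC = d_inner;                       // cf. d_inner*ggml_element_size(zxBCdt)
    const int64_t off_dt  = 2*d_inner + 2*n_group*d_state; // z and xBC sizes combined
    const int64_t row_len = off_dt + n_head;               // == ssm_projection_size

    // Rolling conv cache kept per sequence: the last (d_conv - 1) columns of xBC.
    const int64_t conv_state_per_seq = (d_conv - 1) * (d_inner + 2*n_group*d_state);

    std::printf("z @ %lld, xBC @ %lld, dt @ %lld, row length %lld\n",
                (long long) off_z, (long long) off_xBC, (long long) off_dt, (long long) row_len);
    std::printf("conv state elements per sequence: %lld\n", (long long) conv_state_per_seq);
    return 0;
}
```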
@@ -14837,9 +14836,9 @@ struct llm_build_falcon_h1 : public llm_graph_context {
     // These correspond to V K Q in SSM/attention duality
     ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);

-    ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_ssm*ggml_element_size(xBC));
+    ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));

-    ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_ssm + n_group*d_state)*ggml_element_size(xBC));
+    ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));

     // {n_head, n_seq_tokens, n_seqs}
     dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
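The x/B/C views only work because head_dim*n_head spans exactly the first d_inner elements of each xBC row, with one B and one C block per group packed behind it. A brief sketch of that bookkeeping, again with illustrative dimensions:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative dimensions only.
    const int64_t d_inner = 4096, d_state = 256, n_group = 1, n_head = 128;
    const int64_t head_dim = d_inner / n_head; // the hparams.ssm_head_dim == 0 fallback

    // x is viewed as {head_dim, n_head, ...}, so it must cover exactly d_inner elements.
    assert(head_dim * n_head == d_inner);

    // Per-token layout of an xBC row:
    //   [ x : d_inner | B : n_group*d_state | C : n_group*d_state ]
    const int64_t off_B = d_inner;                   // cf. d_inner*ggml_element_size(xBC)
    const int64_t off_C = d_inner + n_group*d_state; // cf. (d_inner + n_group*d_state)*ggml_element_size(xBC)

    std::printf("head_dim = %lld, B @ %lld, C @ %lld\n",
                (long long) head_dim, (long long) off_B, (long long) off_C);
    return 0;
}
```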
@@ -14862,8 +14861,8 @@ struct llm_build_falcon_h1 : public llm_graph_context {
     // store last states
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-            ggml_view_1d(ctx0, y_ssm, d_state*d_ssm*n_seqs, ggml_nelements(x)*x->nb[0]),
-            ggml_view_1d(ctx0, ssm_states_all, d_state*d_ssm*n_seqs, kv_head*d_state*d_ssm*ggml_element_size(ssm_states_all))));
+            ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
+            ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));

     ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);

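The store above relies on y_ssm packing the per-token outputs first and the final recurrent states last (which is what the ggml_nelements(x)*x->nb[0] offset implies), so the trailing d_state*d_inner*n_seqs elements go into the cache slot for kv_head. A rough sketch of the sizes involved, with hypothetical dimensions:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative dimensions only.
    const int64_t d_inner = 4096, d_state = 256, n_head = 128;
    const int64_t head_dim = d_inner / n_head;
    const int64_t n_seq_tokens = 8, n_seqs = 2;

    // y_ssm = [ per-token outputs y | final recurrent state per sequence ]
    const int64_t y_elems     = head_dim * n_head * n_seq_tokens * n_seqs; // == ggml_nelements(x)
    const int64_t state_elems = d_state * d_inner * n_seqs;                // tail copied into ssm_states_all

    // Each cache slot holds one sequence's state; slot kv_head starts at this element offset.
    const int64_t kv_head = 0; // example slot index
    const int64_t slot_offset = kv_head * d_state * d_inner;

    std::printf("y elements: %lld, trailing state elements: %lld, slot offset: %lld\n",
                (long long) y_elems, (long long) state_elems, (long long) slot_offset);
    return 0;
}
```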
@@ -14878,7 +14877,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
         y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
     }

-    y = ggml_reshape_3d(ctx0, y, d_ssm, n_seq_tokens, n_seqs);
+    y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);

     // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
     cur = build_lora_mm(model.layers[il].ssm_out, y);