
Commit da8a338

Merge branch 'add-fh1-rebased' of https://github.com/tiiuae/llama.cpp-public into add-fh1-rebased
2 parents 67b2664 + 7d7da0b commit da8a338

6 files changed, +18 -26 lines changed

convert_hf_to_gguf.py

Lines changed: 2 additions & 1 deletion
@@ -6674,7 +6674,8 @@ def set_gguf_parameters(self):
 
         # Add Falcon Mamba2 specific configuration
         self.gguf_writer.add_uint32("falcon_h1.attention.head_dim", self.hparams["head_dim"])
-        self.gguf_writer.add_uint32("falcon_h1.ssm.mamba_d_ssm", self.hparams["mamba_d_ssm"])
+        self.gguf_writer.add_uint32("falcon_h1.ssm.mamba_d_inner", self.hparams["mamba_d_ssm"])
+        self.gguf_writer.add_ssm_inner_size(self.hparams["mamba_d_ssm"])
         self.gguf_writer.add_uint32("falcon_h1.num_attention_heads", self.find_hparam(["num_attention_heads"]))
         self.gguf_writer.add_uint32("falcon_h1.num_key_value_heads",
                                     self.find_hparam(["num_key_value_heads"], optional=True) or

src/llama-arch.cpp

Lines changed: 0 additions & 1 deletion
@@ -219,7 +219,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_SEP_ID,        "tokenizer.ggml.fim_sep_token_id" },
 
     { LLM_KV_SSM_HEAD_DIM,                "%s.ssm.head_dim" },
-    { LLM_KV_MAMBA_D_SSM,                 "%s.ssm.mamba_d_ssm" },
 
     { LLM_KV_FALCON_H1_MAMBA_RMS_NORM,    "%s.mamba_rms_norm" },
 
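Side note on how these key strings are used (context, not part of the diff): llama.cpp substitutes the architecture name into the `%s` placeholder when reading metadata, so the removed entry would have produced a key like `falcon_h1.ssm.mamba_d_ssm`, while the generic `%s.ssm.inner_size` entry (LLM_KV_SSM_INNER_SIZE) resolves to the inner-size key that the converter now writes via `add_ssm_inner_size`. A minimal stand-alone sketch of that substitution, assuming a plain printf-style expansion and an illustrative arch string (llama.cpp's own llm_kv helper and the exact Falcon-H1 arch name may differ):

    #include <cstdio>
    #include <string>

    // Illustrative helper only -- not llama.cpp's API. It just shows how a
    // "%s.ssm.*" template from LLM_KV_NAMES becomes a per-architecture key.
    static std::string expand_kv(const char * tmpl, const char * arch) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), tmpl, arch);
        return buf;
    }

    int main() {
        const char * arch = "falcon_h1"; // assumption: whatever LLM_ARCH_NAMES maps Falcon-H1 to
        // After this commit the inner size travels under the generic key:
        std::printf("%s\n", expand_kv("%s.ssm.inner_size", arch).c_str());
        // The dedicated key removed above would have been:
        std::printf("%s\n", expand_kv("%s.ssm.mamba_d_ssm", arch).c_str());
        return 0;
    }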

src/llama-arch.h

Lines changed: 0 additions & 1 deletion
@@ -160,7 +160,6 @@ enum llm_kv {
     // Falcon-H1 specific
     LLM_KV_ATTN_HEAD_DIM,
     LLM_KV_SSM_HEAD_DIM,
-    LLM_KV_MAMBA_D_SSM,
     LLM_KV_N_LAYER,
     LLM_KV_FALCON_H1_MAMBA_RMS_NORM,
 

src/llama-hparams.cpp

Lines changed: 2 additions & 7 deletions
@@ -76,12 +76,7 @@ uint32_t llama_hparams::n_embd_r() const {
     // Corresponds to Mamba's conv_states size
 
     // check if the architecture is using d_ssm
-    if (ssm_mamba_d_ssm > 0) {
-        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_mamba_d_ssm + 2*ssm_n_group*ssm_d_state);
-    } else {
-        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
-    }
-
+    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
 }
 
 uint32_t llama_hparams::n_embd_s() const {

@@ -91,7 +86,7 @@ uint32_t llama_hparams::n_embd_s() const {
     }
 
     // corresponds to Mamba's ssm_states size
-    return (ssm_mamba_d_ssm > 0 ? ssm_d_state * ssm_mamba_d_ssm : ssm_d_state * ssm_d_inner);
+    return ssm_d_state * ssm_d_inner;
 }
 
 bool llama_hparams::is_recurrent(uint32_t il) const {
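Why the branch can go away: with the converter change above, the GGUF inner size is written from `mamba_d_ssm`, so `ssm_d_inner` now carries the same value the removed branch read from `ssm_mamba_d_ssm`, and both arms of the old if/else agree. A small stand-alone check with made-up numbers (not Falcon-H1's real hyperparameters):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Illustrative values only; real checkpoints define their own.
        const uint32_t ssm_d_conv  = 4;    // conv kernel size
        const uint32_t ssm_d_inner = 512;  // now also carries what mamba_d_ssm used to hold
        const uint32_t ssm_n_group = 1;
        const uint32_t ssm_d_state = 128;

        // Simplified n_embd_r(): per-sequence conv state size (Mamba's conv_states)
        const uint32_t n_embd_r = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0)
                                * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);

        // Simplified n_embd_s(): per-sequence ssm state size (Mamba's ssm_states)
        const uint32_t n_embd_s = ssm_d_state * ssm_d_inner;

        std::printf("n_embd_r = %u\n", (unsigned) n_embd_r); // 3 * (512 + 2*1*128) = 2304
        std::printf("n_embd_s = %u\n", (unsigned) n_embd_s); // 128 * 512           = 65536

        // With ssm_mamba_d_ssm equal to ssm_d_inner, the removed branch gave the same result.
        const uint32_t ssm_mamba_d_ssm = ssm_d_inner;
        assert((ssm_d_conv - 1) * (ssm_mamba_d_ssm + 2*ssm_n_group*ssm_d_state) == n_embd_r);
        assert(ssm_d_state * ssm_mamba_d_ssm == n_embd_s);
        return 0;
    }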

src/llama-hparams.h

Lines changed: 0 additions & 1 deletion
@@ -116,7 +116,6 @@ struct llama_hparams {
     uint32_t ssm_dt_rank  = 0;
     uint32_t ssm_n_group  = 0;
     uint32_t ssm_head_dim = 0;
-    uint32_t ssm_mamba_d_ssm = 0;
 
     // for hybrid state space models
     std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;

src/llama-model.cpp

Lines changed: 14 additions & 15 deletions
@@ -1555,7 +1555,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
            ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
            // SSM parameters
-           ml.get_key(LLM_KV_MAMBA_D_SSM,     hparams.ssm_mamba_d_ssm);
            ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
            ml.get_key(LLM_KV_SSM_INNER_SIZE,  hparams.ssm_d_inner);
            ml.get_key(LLM_KV_SSM_STATE_SIZE,  hparams.ssm_d_state);

@@ -4514,7 +4513,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                const int64_t ssm_conv_kernel_size  = hparams.ssm_d_conv;  // ssm_conv_kernel_size
                const int64_t ssm_n_groups          = hparams.ssm_n_group; // ssm_n_groups
                const int64_t ssm_state_size        = hparams.ssm_d_state; // ssm_state_size
-               const int64_t ssm_mamba_d_ssm       = hparams.ssm_mamba_d_ssm;
+               const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand
                const int64_t ssm_num_heads         = hparams.ssm_dt_rank; // ssm_num_heads
                const int64_t ssm_conv_dim          = ssm_mamba_d_ssm + 2 * ssm_n_groups * ssm_state_size;
                const int64_t ssm_projection_size   = ssm_mamba_d_ssm + ssm_conv_dim + ssm_num_heads;

@@ -14768,10 +14767,10 @@ struct llm_build_falcon_h1 : public llm_graph_context {
        const auto kv_head = kv_state->get_head();
 
        const int64_t d_conv   = hparams.ssm_d_conv;
-       const int64_t d_ssm    = hparams.ssm_mamba_d_ssm;
+       const int64_t d_inner  = hparams.ssm_d_inner;
        const int64_t d_state  = hparams.ssm_d_state;
        const int64_t n_head   = hparams.ssm_dt_rank;
-       const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_ssm / n_head : hparams.ssm_head_dim;
+       const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim;
        const int64_t n_group  = hparams.ssm_n_group;
        const int64_t n_seqs   = ubatch.n_seqs;
 

@@ -14785,7 +14784,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
        ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
 
        ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
-       conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs);
+       conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
 
        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

@@ -14798,22 +14797,22 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
        // split the above in three
        ggml_tensor * z   = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
-       ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_ssm + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_ssm*ggml_element_size(zxBCdt));
-       ggml_tensor * dt  = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_ssm + 2*n_group*d_state)*ggml_element_size(zxBCdt));
+       ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
+       ggml_tensor * dt  = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
 
        // conv
        {
            // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
 
            // copy last (d_conv - 1) columns back into the state cache
-           ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+           ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
 
            ggml_build_forward_expand(gf,
                ggml_cpy(ctx0, last_conv,
                    ggml_view_1d(ctx0, conv_states_all,
-                       (d_conv - 1)*(d_ssm + 2*n_group*d_state)*(n_seqs),
-                       kv_head*(d_conv - 1)*(d_ssm + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+                       (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+                       kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
 
            // 1D convolution
            // The equivalent is to make a self-overlapping view of conv_x

@@ -14837,9 +14836,9 @@ struct llm_build_falcon_h1 : public llm_graph_context {
        // These correspond to V K Q in SSM/attention duality
        ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
 
-       ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_ssm*ggml_element_size(xBC));
+       ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
 
-       ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_ssm + n_group*d_state)*ggml_element_size(xBC));
+       ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
 
        // {n_head, n_seq_tokens, n_seqs}
        dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);

@@ -14862,8 +14861,8 @@ struct llm_build_falcon_h1 : public llm_graph_context {
        // store last states
        ggml_build_forward_expand(gf,
            ggml_cpy(ctx0,
-               ggml_view_1d(ctx0, y_ssm, d_state*d_ssm*n_seqs, ggml_nelements(x)*x->nb[0]),
-               ggml_view_1d(ctx0, ssm_states_all, d_state*d_ssm*n_seqs, kv_head*d_state*d_ssm*ggml_element_size(ssm_states_all))));
+               ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
+               ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
 
        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
 

@@ -14878,7 +14877,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
        }
 
-       y = ggml_reshape_3d(ctx0, y, d_ssm, n_seq_tokens, n_seqs);
+       y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
 
        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        cur = build_lora_mm(model.layers[il].ssm_out, y);
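All of the offset changes above follow one layout: the projected zxBCdt tensor concatenates z (d_inner elements), xBC (d_inner + 2*n_group*d_state elements) and dt (n_head elements) along its innermost dimension, and the ggml_view_* calls index into that row by a running element count times the element size. A stand-alone recomputation of those offsets with illustrative sizes (real models supply their own):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Illustrative sizes; real values come from the model's hyperparameters.
        const int64_t d_inner = 512;
        const int64_t n_group = 1;
        const int64_t d_state = 128;
        const int64_t n_head  = 16;

        // zxBCdt row layout assumed by the views above:
        //   [ z : d_inner | xBC : d_inner + 2*n_group*d_state | dt : n_head ]
        const int64_t off_z   = 0;
        const int64_t off_xBC = d_inner;                       // element offset used for the xBC view
        const int64_t off_dt  = 2*d_inner + 2*n_group*d_state; // element offset used for the dt view
        const int64_t row     = off_dt + n_head;               // total elements per token (projection size)

        std::printf("z at %lld, xBC at %lld, dt at %lld, row size %lld\n",
                    (long long) off_z, (long long) off_xBC, (long long) off_dt, (long long) row);

        // Within xBC the same idea splits x, B and C:
        //   x at 0, B at d_inner, C at d_inner + n_group*d_state (in elements)
        std::printf("B at %lld, C at %lld\n",
                    (long long) d_inner, (long long) (d_inner + n_group*d_state));
        return 0;
    }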
