@@ -1552,7 +1552,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_FALCON_H1:
             {
                 // Common parameters
-                ml.get_key(LLM_KV_VOCAB_SIZE, hparams.vocab_size);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 // SSM parameters
@@ -1564,10 +1563,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
                 ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim);
 
-                // Falcon-H1 parameters
-                ml.get_key(LLM_KV_ATTN_HEAD_DIM, hparams.attn_head_dim);
-                ml.get_key(LLM_KV_FALCON_H1_MAMBA_RMS_NORM, hparams.mamba_rms_norm);
-
                 std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
 
                 switch (hparams.n_layer) {
@@ -4514,31 +4509,29 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             {
                 // Common
                 const int64_t hidden_size = hparams.n_embd; // hidden_size
-                const int64_t vocab_size = hparams.vocab_size; // vocab_size
 
                 // mamba2 Mixer SSM params
                 const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size
                 const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups
                 const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size
-                const int64_t ssm_intermediate_size = hparams.ssm_mamba_d_ssm > 0 ? hparams.ssm_mamba_d_ssm : int( hparams.mamba_expand * hidden_size); // TODO expand
+                const int64_t ssm_mamba_d_ssm = hparams.ssm_mamba_d_ssm;
                 const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads
-                const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
-                const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;
+                const int64_t ssm_conv_dim = ssm_mamba_d_ssm + 2 * ssm_n_groups * ssm_state_size;
+                const int64_t ssm_projection_size = ssm_mamba_d_ssm + ssm_conv_dim + ssm_num_heads;
 
                 // attn params
                 const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head
                 const int64_t attn_num_key_value_head = hparams.n_head_kv(0);
-                const int64_t attn_head_dim = hparams.attn_head_dim > 0 ? hparams.attn_head_dim : hidden_size / attn_num_attention_head;
 
                 // ffn params
                 const int64_t ffn_intermediate_size = hparams.n_ff(0);
 
                 // embeddings
-                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, vocab_size}, 0);
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0);
 
                 // output
                 {
-                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, vocab_size}, TENSOR_NOT_REQUIRED);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED);
                     final_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0);
                 }
 
@@ -4558,21 +4551,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0);
                     layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0);
                     // ssm_norm
-                    if (hparams.mamba_rms_norm == true) {
-                        layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0);
-                    }
+                    layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_mamba_d_ssm / ssm_n_groups, ssm_n_groups}, 0);
                     // out_proj
-                    layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0);
+                    layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_mamba_d_ssm, hidden_size}, 0);
 
                     /*ATTENTION LAYERS*/
                     // attention layers (with optional bias)
-                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, attn_head_dim * attn_num_attention_head}, 0);
-                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0);
-                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0);
-                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {attn_head_dim * attn_num_attention_head, hidden_size}, 0);
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0);
                     layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * n_embd_head_k}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * n_embd_head_v}, llama_model_loader::TENSOR_NOT_REQUIRED);
                     layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0);
 
@@ -14717,7 +14708,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
             inpSA = ggml_add(ctx0, cur, inpSA);
             cb(cur, "layer_out", il);
 
-            if (il == n_layer - 1) {
+            if (il == n_layer - 1 && inp_out_ids) {
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
@@ -14882,7 +14873,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
         y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
 
         // grouped RMS norm
-        if (hparams.mamba_rms_norm) {
+        if (model.layers[il].ssm_norm) {
             y = ggml_reshape_4d(ctx0, y, d_ssm / n_group, n_group, n_seq_tokens, n_seqs);
             y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
         }
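
Net effect of the diff: the Falcon-H1 path drops its arch-specific vocab_size, attn_head_dim, and mamba_rms_norm keys in favor of the generic n_vocab and n_embd_head_k / n_embd_head_v hparams, and the graph builder now gates the grouped RMS norm on whether the per-layer ssm_norm tensor exists. The standalone sketch below only re-checks the Mamba-2 size arithmetic used by the new load_tensors lines; the numeric values are illustrative assumptions, not taken from any real Falcon-H1 GGUF.

// standalone sketch (not part of the diff): size relations behind the "+" lines above
#include <cassert>
#include <cstdint>

int main() {
    // hypothetical hparams, for illustration only
    const int64_t ssm_mamba_d_ssm = 4096; // hparams.ssm_mamba_d_ssm
    const int64_t ssm_n_groups    = 1;    // hparams.ssm_n_group
    const int64_t ssm_state_size  = 128;  // hparams.ssm_d_state
    const int64_t ssm_num_heads   = 128;  // hparams.ssm_dt_rank

    // same formulas as the new ssm_conv_dim / ssm_projection_size lines in load_tensors()
    const int64_t ssm_conv_dim        = ssm_mamba_d_ssm + 2 * ssm_n_groups * ssm_state_size;
    const int64_t ssm_projection_size = ssm_mamba_d_ssm + ssm_conv_dim + ssm_num_heads;

    // the {ssm_mamba_d_ssm / ssm_n_groups, ssm_n_groups} ssm_norm shape and the
    // ggml_reshape_4d(..., d_ssm / n_group, n_group, ...) call both assume an even split
    assert(ssm_mamba_d_ssm % ssm_n_groups == 0);

    (void) ssm_conv_dim;
    (void) ssm_projection_size;
    return 0;
}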