
Commit 67b2664

cleaning unused hparams
1 parent d2f46f1 commit 67b2664

2 files changed: +15 -33 lines

src/llama-hparams.h

Lines changed: 0 additions & 9 deletions
@@ -118,15 +118,6 @@ struct llama_hparams {
     uint32_t ssm_head_dim = 0;
     uint32_t ssm_mamba_d_ssm = 0;

-    uint32_t attn_head_dim = 0;
-    bool mamba_rms_norm = false;
-    uint32_t vocab_size = 0;
-    uint32_t intermediate_size = 0;
-    float mamba_expand = 0.0f;
-    bool ssm_rms_norm = false;
-    bool ssm_conv_bias = false;
-    bool ssm_proj_bias = false;
-
     // for hybrid state space models
     std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;

src/llama-model.cpp

Lines changed: 15 additions & 24 deletions
@@ -1552,7 +1552,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 case LLM_ARCH_FALCON_H1:
     {
         // Common parameters
-        ml.get_key(LLM_KV_VOCAB_SIZE, hparams.vocab_size);
         ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

         // SSM parameters
@@ -1564,10 +1563,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
         ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim);

-        // Falcon-H1 parameters
-        ml.get_key(LLM_KV_ATTN_HEAD_DIM, hparams.attn_head_dim);
-        ml.get_key(LLM_KV_FALCON_H1_MAMBA_RMS_NORM, hparams.mamba_rms_norm);
-
         std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);

         switch (hparams.n_layer) {
@@ -4514,31 +4509,29 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 {
     // Common
     const int64_t hidden_size = hparams.n_embd; // hidden_size
-    const int64_t vocab_size = hparams.vocab_size; // vocab_size

     // mamba2 Mixer SSM params
     const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size
     const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups
     const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size
-    const int64_t ssm_intermediate_size = hparams.ssm_mamba_d_ssm > 0 ? hparams.ssm_mamba_d_ssm : int(hparams.mamba_expand * hidden_size); // TODO expand
+    const int64_t ssm_mamba_d_ssm = hparams.ssm_mamba_d_ssm;
     const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads
-    const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
-    const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;
+    const int64_t ssm_conv_dim = ssm_mamba_d_ssm + 2 * ssm_n_groups * ssm_state_size;
+    const int64_t ssm_projection_size = ssm_mamba_d_ssm + ssm_conv_dim + ssm_num_heads;

     // attn params
     const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head
     const int64_t attn_num_key_value_head = hparams.n_head_kv(0);
-    const int64_t attn_head_dim = hparams.attn_head_dim > 0 ? hparams.attn_head_dim : hidden_size / attn_num_attention_head;

     // ffn params
     const int64_t ffn_intermediate_size = hparams.n_ff(0);

     // embeddings
-    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, vocab_size}, 0);
+    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0);

     // output
     {
-        output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, vocab_size}, TENSOR_NOT_REQUIRED);
+        output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED);
         final_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0);
     }
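As a quick sanity check of the derived sizes above, here is a worked example with purely illustrative numbers (the values are hypothetical, not taken from a real Falcon-H1 GGUF; only the formulas match the loader code in the hunk above):

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical hparams values, for illustration only.
    const int64_t ssm_mamba_d_ssm = 4096; // hparams.ssm_mamba_d_ssm
    const int64_t ssm_n_groups    = 1;    // hparams.ssm_n_group
    const int64_t ssm_state_size  = 128;  // hparams.ssm_d_state
    const int64_t ssm_num_heads   = 128;  // hparams.ssm_dt_rank

    // Same arithmetic as the loader: the conv input carries the SSM activations
    // plus the grouped B and C projections, and the in_proj output is sized to
    // produce all of that plus the per-head dt values in a single matmul.
    const int64_t ssm_conv_dim        = ssm_mamba_d_ssm + 2 * ssm_n_groups * ssm_state_size;
    const int64_t ssm_projection_size = ssm_mamba_d_ssm + ssm_conv_dim + ssm_num_heads;

    printf("ssm_conv_dim        = %lld\n", (long long) ssm_conv_dim);        // 4096 + 2*1*128 = 4352
    printf("ssm_projection_size = %lld\n", (long long) ssm_projection_size); // 4096 + 4352 + 128 = 8576
    return 0;
}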

@@ -4558,21 +4551,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0);
     layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0);
     // ssm_norm
-    if (hparams.mamba_rms_norm == true) {
-        layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0);
-    }
+    layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_mamba_d_ssm / ssm_n_groups, ssm_n_groups}, 0);
     // out_proj
-    layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0);
+    layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_mamba_d_ssm, hidden_size}, 0);

     /*ATTENTION LAYERS*/
     // attention layers (with optional bias)
-    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, attn_head_dim * attn_num_attention_head}, 0);
-    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0);
-    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0);
-    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {attn_head_dim * attn_num_attention_head, hidden_size}, 0);
+    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0);
+    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0);
+    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0);
+    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0);
     layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
-    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED);
-    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED);
+    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * n_embd_head_k}, llama_model_loader::TENSOR_NOT_REQUIRED);
+    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * n_embd_head_v}, llama_model_loader::TENSOR_NOT_REQUIRED);
     layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0);
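The n_embd_head_k and n_embd_head_v used in the new tensor shapes are the generic per-head dimensions that llama_hparams already carries, which is what makes the removed Falcon-H1-specific attn_head_dim fallback redundant. A minimal sketch of how those generic fields are typically populated in load_hparams (an assumption about the surrounding loader code, not part of this diff):

// Sketch, assumption: default the generic head dims to n_embd / n_head,
// then let the optional GGUF key/value-length keys override them.
hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH,   hparams.n_embd_head_k, /*required =*/ false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, /*required =*/ false);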

@@ -14717,7 +14708,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
     inpSA = ggml_add(ctx0, cur, inpSA);
     cb(cur, "layer_out", il);

-    if (il == n_layer - 1) {
+    if (il == n_layer - 1 && inp_out_ids) {
         cur = ggml_get_rows(ctx0, cur, inp_out_ids);
         inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
     }
@@ -14882,7 +14873,7 @@ struct llm_build_falcon_h1 : public llm_graph_context {
     y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));

     // grouped RMS norm
-    if (hparams.mamba_rms_norm){
+    if (model.layers[il].ssm_norm) {
         y = ggml_reshape_4d(ctx0, y, d_ssm / n_group, n_group, n_seq_tokens, n_seqs);
         y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
     }
