Commit 4fe5c96

Fix num heads
1 parent 433782b commit 4fe5c96

3 files changed: +22 -19 lines changed

convert_hf_to_gguf.py

Lines changed: 10 additions & 7 deletions
@@ -4185,7 +4185,8 @@ def set_gguf_parameters(self):
         # This logic matches modeling_plamo.py's is_mamba function
         mamba_step = hparams.get("mamba_step", 2)
         mamba_enabled = hparams.get("mamba_enabled", True)
-        mamba_layers = []
+        num_key_value_heads = []
+        num_attention_heads = []
 
         if mamba_enabled:
             for i in range(block_count):
@@ -4195,18 +4196,20 @@ def set_gguf_parameters(self):
                 else:
                     is_mamba = (i % mamba_step) != (mamba_step // 2)
                 if is_mamba:
-                    mamba_layers.append(0)
+                    num_key_value_heads.append(0)
                 else:
-                    mamba_layers.append(hparams.get("num_key_value_heads", 4))
+                    num_key_value_heads.append(hparams.get("num_key_value_heads", 4))
+                    num_attention_heads.append(hparams.get("num_attention_heads", 32))
 
-        if mamba_layers:
-            self.gguf_writer.add_head_count_kv(mamba_layers)
+        if num_key_value_heads and num_attention_heads:
+            self.gguf_writer.add_head_count_kv(num_key_value_heads)
+            self.gguf_writer.add_head_count(num_attention_heads)
 
         self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
         self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
-        self.gguf_writer.add_features_length(hparams.get("hidden_size_per_head", 128))
+        self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
+        self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
         self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_wkv_head_size(hparams.get("num_attention_heads", 32))
         self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
 
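Read together, the two hunks above replace the single mamba_layers list with two per-layer lists: a Mamba layer contributes a 0 to the KV-head list and nothing to the attention-head list, while an attention layer contributes both counts. Below is a minimal standalone sketch of that selection logic; the hparams values and block_count are made-up examples, and the shallow-model special case that precedes the else: branch in the hunk is left out.

```python
# Minimal sketch of the per-layer head-count selection after this commit.
# hparams and block_count are made-up example values, not from a real config.
hparams = {
    "mamba_step": 2,
    "mamba_enabled": True,
    "num_key_value_heads": 4,
    "num_attention_heads": 32,
}
block_count = 4

num_key_value_heads = []
num_attention_heads = []

if hparams.get("mamba_enabled", True):
    mamba_step = hparams.get("mamba_step", 2)
    for i in range(block_count):
        # Every mamba_step-th block (offset by mamba_step // 2) is an attention
        # block; the rest are Mamba blocks.
        is_mamba = (i % mamba_step) != (mamba_step // 2)
        if is_mamba:
            num_key_value_heads.append(0)  # Mamba layer: no KV heads
        else:
            num_key_value_heads.append(hparams.get("num_key_value_heads", 4))
            num_attention_heads.append(hparams.get("num_attention_heads", 32))

print(num_key_value_heads)  # [0, 4, 0, 4]
print(num_attention_heads)  # [32, 32]
```

Because num_attention_heads only collects entries for attention layers, the two lists can end up with different lengths, and the `if num_key_value_heads and num_attention_heads:` guard writes neither key when the model has no attention layers at all.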

src/llama-hparams.h

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ struct llama_hparams {
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
-    int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
+    int32_t  n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
     uint32_t n_rot;
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head

src/llama-model.cpp

Lines changed: 11 additions & 11 deletions
@@ -3367,15 +3367,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_PLAMO2:
                 {
-                    const uint32_t d_conv = hparams.ssm_d_conv;
-                    const uint32_t d_state = hparams.ssm_d_state;
-                    const uint32_t num_heads = hparams.ssm_dt_rank;
-                    const uint32_t intermediate_size = hparams.ssm_d_inner;
-                    const uint32_t head_dim = hparams.wkv_head_size;
-                    const uint32_t qk_dim = head_dim;
-                    const uint32_t v_dim = head_dim;
-                    const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
-
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
@@ -3390,12 +3381,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         auto & layer = layers[i];
                         bool is_mamba_layer = hparams.is_recurrent(i);
 
-                        const int64_t num_attention_heads = hparams.n_head_kv_arr[i];
-                        const int64_t q_num_heads = num_attention_heads;
 
                         layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
                         if (is_mamba_layer) {
+                            const uint32_t d_conv = hparams.ssm_d_conv;
+                            const uint32_t d_state = hparams.ssm_d_state;
+                            const uint32_t num_heads = hparams.ssm_dt_rank;
+                            const uint32_t intermediate_size = hparams.ssm_d_inner;
+                            const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+
                             layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0);
                             layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);
 
@@ -3412,6 +3407,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
                             layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
                         } else {
+                            const uint32_t head_dim = hparams.n_embd_head_k;
+                            const uint32_t qk_dim = head_dim;
+                            const uint32_t v_dim = head_dim;
+                            const int64_t num_attention_heads = hparams.n_head(i);
+                            const int64_t q_num_heads = num_attention_heads;
                             const int64_t num_key_value_heads = hparams.n_head_kv(i);
                             const int64_t k_num_heads = num_key_value_heads;
                             const int64_t v_num_heads = num_key_value_heads;
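On the model-loading side, the same per-layer information is consumed layer by layer: the SSM dimensions are now computed only inside the is_mamba_layer branch, while the attention branch takes its head counts from the per-layer getters hparams.n_head(i) and hparams.n_head_kv(i) and its head dimension from n_embd_head_k, replacing the removed wkv_head_size and n_head_kv_arr[i] lookups. Below is a rough Python rendering of that per-layer dispatch, using a zero KV-head count as the Mamba marker (mirroring the zeros written by the conversion script) and made-up arrays rather than real GGUF metadata.

```python
# Rough sketch of the per-layer dispatch performed by the loader after this
# commit. The arrays below are made-up examples, not real GGUF metadata.
n_head_arr    = [0, 32, 0, 32]   # attention (Q) heads per layer
n_head_kv_arr = [0, 4, 0, 4]     # KV heads per layer; 0 marks a Mamba layer
n_embd_head_k = 128              # per-head dimension (hidden_size_per_head)

for i, n_head_kv in enumerate(n_head_kv_arr):
    if n_head_kv == 0:
        # recurrent (Mamba) layer: SSM tensors, no attention heads
        print(f"layer {i}: Mamba block")
    else:
        n_head = n_head_arr[i]
        print(f"layer {i}: attention block, {n_head} Q heads, "
              f"{n_head_kv} KV heads, head dim {n_embd_head_k}")
```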
