@@ -1084,7 +1084,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                         }
                         break;
                     default: type = LLM_TYPE_UNKNOWN;
-                }
+                }
+
+                // Load attention parameters
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
             } break;
         case LLM_ARCH_GPT2:
             {
@@ -3392,17 +3396,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_PLAMO2:
                 {
+                    // mamba parameters
                     const uint32_t d_conv = hparams.ssm_d_conv;
                     const uint32_t d_state = hparams.ssm_d_state;
                     const uint32_t num_heads = hparams.ssm_dt_rank;
                     const uint32_t intermediate_size = hparams.ssm_d_inner;
-                    const uint32_t head_dim = intermediate_size / num_heads;
-                    const uint32_t qk_dim = head_dim;
-                    const uint32_t v_dim = head_dim;
-                    const int64_t num_attention_heads = hparams.n_head();
-                    const int64_t q_num_heads = num_attention_heads;
                     const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));

+                    // attention parameters
+                    const uint32_t qk_dim = hparams.n_embd_head_k;
+                    const uint32_t v_dim = hparams.n_embd_head_v;
+
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                     // output
@@ -3436,6 +3440,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
                         layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
                     } else {
+                        const int64_t num_attention_heads = hparams.n_head(i);
+                        const int64_t q_num_heads = num_attention_heads;
                         const int64_t num_key_value_heads = hparams.n_head_kv(i);
                         const int64_t k_num_heads = num_key_value_heads;
                         const int64_t v_num_heads = num_key_value_heads;
@@ -3444,8 +3450,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         const int64_t v_proj_dim = v_num_heads * v_dim;

                         layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0);
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0);
                     }

@@ -17611,6 +17617,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
         const int64_t n_embd_head_q = hparams.n_embd_head_k;
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
         const int64_t n_embd_head_v = hparams.n_embd_head_v;
+        int32_t n_head = hparams.n_head(il);
         int32_t n_head_kv = hparams.n_head_kv(il);

         const int64_t q_offset = 0;
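
For reference, a minimal standalone sketch (not part of the patch) of the dimension bookkeeping this change introduces: PLaMo2's attention head sizes now come from the GGUF attention keys loaded above (hparams.n_embd_head_k / n_embd_head_v) rather than being derived from the mamba intermediate size. The struct and helpers below are illustrative assumptions, not llama.cpp API.

#include <cstdint>

// Illustrative only: field names mirror llama.cpp's hparams, but this struct
// and the helpers below are not part of the patch or of llama.cpp itself.
struct plamo2_dims_sketch {
    uint32_t ssm_dt_rank;    // number of mamba heads
    uint32_t ssm_d_inner;    // mamba intermediate size
    uint32_t n_embd_head_k;  // attention Q/K head dim (now read from GGUF)
    uint32_t n_embd_head_v;  // attention V head dim (now read from GGUF)
};

// Before: the attention head dims were tied to the mamba layout.
static inline uint32_t old_attn_head_dim(const plamo2_dims_sketch & hp) {
    return hp.ssm_d_inner / hp.ssm_dt_rank;  // qk_dim == v_dim == head_dim
}

// After: qk_dim and v_dim come straight from the loaded hparams.
static inline void new_attn_dims(const plamo2_dims_sketch & hp, uint32_t & qk_dim, uint32_t & v_dim) {
    qk_dim = hp.n_embd_head_k;
    v_dim  = hp.n_embd_head_v;
}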