@@ -3367,15 +3367,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_PLAMO2:
                 {
-                    const uint32_t d_conv = hparams.ssm_d_conv;
-                    const uint32_t d_state = hparams.ssm_d_state;
-                    const uint32_t num_heads = hparams.ssm_dt_rank;
-                    const uint32_t intermediate_size = hparams.ssm_d_inner;
-                    const uint32_t head_dim = hparams.wkv_head_size;
-                    const uint32_t qk_dim = head_dim;
-                    const uint32_t v_dim = head_dim;
-                    const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
-
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
                     // output
@@ -3390,12 +3381,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         auto & layer = layers[i];
                         bool is_mamba_layer = hparams.is_recurrent(i);
 
-                        const int64_t num_attention_heads = hparams.n_head_kv_arr[i];
-                        const int64_t q_num_heads = num_attention_heads;
 
                         layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
                         if (is_mamba_layer) {
+                            const uint32_t d_conv = hparams.ssm_d_conv;
+                            const uint32_t d_state = hparams.ssm_d_state;
+                            const uint32_t num_heads = hparams.ssm_dt_rank;
+                            const uint32_t intermediate_size = hparams.ssm_d_inner;
+                            const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+
                             layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0);
                             layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);
 
@@ -3412,6 +3407,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
                             layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
                         } else {
+                            const uint32_t head_dim = hparams.n_embd_head_k;
+                            const uint32_t qk_dim = head_dim;
+                            const uint32_t v_dim = head_dim;
+                            const int64_t num_attention_heads = hparams.n_head(i);
+                            const int64_t q_num_heads = num_attention_heads;
                             const int64_t num_key_value_heads = hparams.n_head_kv(i);
                             const int64_t k_num_heads = num_key_value_heads;
                             const int64_t v_num_heads = num_key_value_heads;
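
For context, here is roughly how the PLAMO2 per-layer loop reads after this change, reconstructed only from the hunks above. The tensor-creation lines that fall outside the shown diff context are elided and marked as such, so treat this as a sketch rather than a verbatim copy of llama-model.cpp.

// Sketch reconstructed from the hunks above -- not a verbatim copy of llama-model.cpp.
for (int i = 0; i < n_layer; ++i) {
    auto & layer = layers[i];
    bool is_mamba_layer = hparams.is_recurrent(i);

    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

    if (is_mamba_layer) {
        // SSM hyperparameters are now read only in the branch that uses them
        const uint32_t d_conv            = hparams.ssm_d_conv;
        const uint32_t d_state           = hparams.ssm_d_state;
        const uint32_t num_heads         = hparams.ssm_dt_rank;
        const uint32_t intermediate_size = hparams.ssm_d_inner;
        const int64_t  dt_dim            = std::max(64, int(hparams.n_embd / 16));

        layer.ssm_in     = create_tensor(tn(LLM_TENSOR_SSM_IN,     "weight", i), {n_embd, 2 * intermediate_size}, 0);
        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);
        // ... SSM tensors outside the shown diff context elided ...
        layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
        layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
    } else {
        // attention head geometry now comes from the per-layer accessors
        const uint32_t head_dim            = hparams.n_embd_head_k; // was hparams.wkv_head_size
        const uint32_t qk_dim              = head_dim;
        const uint32_t v_dim               = head_dim;
        const int64_t  num_attention_heads = hparams.n_head(i);     // was hparams.n_head_kv_arr[i]
        const int64_t  q_num_heads         = num_attention_heads;
        const int64_t  num_key_value_heads = hparams.n_head_kv(i);
        const int64_t  k_num_heads         = num_key_value_heads;
        const int64_t  v_num_heads         = num_key_value_heads;
        // ... attention tensors outside the shown diff context elided ...
    }
}

The net effect is that each branch reads only the hyperparameters it actually needs, and the attention branch now derives its head counts and head dimension from the per-layer accessors (hparams.n_head(i), hparams.n_head_kv(i), hparams.n_embd_head_k) instead of the SSM-oriented fields it used before.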