Commit 07b55f4

Move shared parameter definitions to the outside of the loop
1 parent 90fbf6a commit 07b55f4

File tree

1 file changed: +13 −12 lines changed


src/llama-model.cpp

Lines changed: 13 additions & 12 deletions

@@ -3367,6 +3367,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_PLAMO2:
                 {
+                    // mamba parameters
+                    const uint32_t d_conv = hparams.ssm_d_conv;
+                    const uint32_t d_state = hparams.ssm_d_state;
+                    const uint32_t num_heads = hparams.ssm_dt_rank;
+                    const uint32_t intermediate_size = hparams.ssm_d_inner;
+                    const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+
+                    // attention parameters
+                    const uint32_t qk_dim = hparams.n_embd_head_k;
+                    const uint32_t v_dim = hparams.n_embd_head_v;
+
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                     // output
@@ -3381,16 +3392,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         auto & layer = layers[i];
                         bool is_mamba_layer = hparams.is_recurrent(i);

-
                         layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

                         if (is_mamba_layer) {
-                            const uint32_t d_conv = hparams.ssm_d_conv;
-                            const uint32_t d_state = hparams.ssm_d_state;
-                            const uint32_t num_heads = hparams.ssm_dt_rank;
-                            const uint32_t intermediate_size = hparams.ssm_d_inner;
-                            const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
-
                             layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0);
                             layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0);

@@ -3407,9 +3411,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
                             layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
                         } else {
-                            const uint32_t head_dim = hparams.n_embd_head_k;
-                            const uint32_t qk_dim = head_dim;
-                            const uint32_t v_dim = head_dim;
                             const int64_t num_attention_heads = hparams.n_head(i);
                             const int64_t q_num_heads = num_attention_heads;
                             const int64_t num_key_value_heads = hparams.n_head_kv(i);
@@ -3420,8 +3421,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             const int64_t v_proj_dim = v_num_heads * v_dim;

                             layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0);
-                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0);
-                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0);
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0);
                             layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0);
                         }

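For readers skimming the diff: the refactor hoists the parameters that do not depend on the layer index (the Mamba sizes and the per-head dimensions) out of the per-layer loop, so they are defined once and shared by both the recurrent and attention branches; it also takes the value-head size from hparams.n_embd_head_v instead of reusing n_embd_head_k, so the Q/K and V head sizes stay independent in the updated attn_q_norm/attn_k_norm shapes. The following is a minimal standalone sketch of the same hoisting pattern; the HParams struct, its default values, and the layer count are placeholders for illustration, not the llama.cpp API.

// Sketch of hoisting loop-invariant parameter definitions (hypothetical values/names).
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct HParams {
    uint32_t n_embd        = 2048;
    uint32_t ssm_d_conv    = 4;
    uint32_t ssm_d_state   = 64;
    uint32_t n_embd_head_k = 128;
    uint32_t n_embd_head_v = 128;
};

int main() {
    HParams hparams;
    const int n_layer = 4;  // placeholder layer count

    // Layer-independent parameters are computed once, outside the loop,
    // and are visible to both the Mamba and the attention branches.
    const uint32_t d_conv  = hparams.ssm_d_conv;
    const uint32_t d_state = hparams.ssm_d_state;
    const int64_t  dt_dim  = std::max(64, int(hparams.n_embd / 16));
    const uint32_t qk_dim  = hparams.n_embd_head_k;  // query/key head size
    const uint32_t v_dim   = hparams.n_embd_head_v;  // value head size, no longer tied to qk_dim

    for (int i = 0; i < n_layer; ++i) {
        const bool is_mamba_layer = (i % 2 == 0);  // stand-in for hparams.is_recurrent(i)
        if (is_mamba_layer) {
            std::printf("layer %d: mamba d_conv=%u d_state=%u dt_dim=%lld\n",
                        i, d_conv, d_state, (long long) dt_dim);
        } else {
            std::printf("layer %d: attn  qk_dim=%u v_dim=%u\n", i, qk_dim, v_dim);
        }
    }
    return 0;
}

The behavior is unchanged when n_embd_head_k and n_embd_head_v are equal; the hoisting simply removes the duplicated per-branch definitions inside the loop.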