@@ -4626,16 +4626,6 @@ static void llm_load_hparams(
 
     // non-transformer models do not have attention heads
     if (hparams.n_head() > 0) {
-        // sanity check for n_rot (optional)
-        hparams.n_rot = hparams.n_embd / hparams.n_head();
-
-        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-
-        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
-            if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
-                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
-            }
-        }
         // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
         // gpt-j n_rot = rotary_dim
 
@@ -4644,6 +4634,17 @@ static void llm_load_hparams(
 
         hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
         ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
+
+        // sanity check for n_rot (optional)
+        hparams.n_rot = hparams.n_embd_head_k;
+
+        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
+
+        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
+            if (hparams.n_rot != hparams.n_embd_head_k) {
+                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
+            }
+        }
     } else {
         hparams.n_rot = 0;
         hparams.n_embd_head_k = 0;
@@ -11491,7 +11492,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -11500,7 +11501,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
@@ -11604,7 +11605,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -11613,7 +11614,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
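For context on why the rope calls above switch from n_embd_head_k to n_rot: for architectures with partial rotary embeddings the two values differ, as the gpt-neox/gpt-j comments in the first hunk note. Below is a minimal, self-contained sketch of the relationship; the concrete numbers are illustrative assumptions, not values from any real model.

// Sketch: how n_rot can differ from n_embd_head_k (illustrative values only).
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_embd        = 4096;
    const uint32_t n_head        = 32;
    const uint32_t n_embd_head_k = n_embd / n_head;   // 128 unless overridden via LLM_KV_ATTENTION_KEY_LENGTH
    const float    rotary_pct    = 0.25f;             // gpt-neox style partial rotary

    // gpt-neox: n_rot = rotary_pct * (n_embd / n_head); only the first n_rot
    // dimensions of each head are rotated by RoPE, the rest pass through.
    const uint32_t n_rot = (uint32_t)(rotary_pct * (float)n_embd_head_k);   // 32

    printf("n_embd_head_k = %u, n_rot = %u\n", n_embd_head_k, n_rot);
    return 0;
}

Passing n_rot to ggml_rope_ext therefore rotates only the intended number of dimensions per head, while n_embd_head_k would rotate the full head for such models; for LLaMA and Falcon the loader still enforces n_rot == n_embd_head_k.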