@@ -4745,16 +4745,6 @@ static void llm_load_hparams(
47454745
47464746    // non-transformer models do not have attention heads
47474747    if (hparams.n_head() > 0) {
4748-         // sanity check for n_rot (optional)
4749-         hparams.n_rot = hparams.n_embd / hparams.n_head();
4750- 
4751-         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4752- 
4753-         if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4754-             if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
4755-                 throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
4756-             }
4757-         }
47584748        // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
47594749        // gpt-j n_rot = rotary_dim
47604750
@@ -4763,6 +4753,17 @@ static void llm_load_hparams(
47634753
47644754        hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
47654755        ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
4756+ 
4757+         // sanity check for n_rot (optional)
4758+         hparams.n_rot = hparams.n_embd_head_k;
4759+ 
4760+         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4761+ 
4762+         if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4763+             if (hparams.n_rot != hparams.n_embd_head_k) {
4764+                 throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
4765+             }
4766+         }
47664767    } else {
47674768        hparams.n_rot = 0;
47684769        hparams.n_embd_head_k = 0;
@@ -11633,7 +11634,7 @@ struct llm_build_context {
1163311634
1163411635                Qcur = ggml_rope_ext(
1163511636                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
11636-                         n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11637+                         n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
1163711638                        ext_factor, attn_factor, beta_fast, beta_slow);
1163811639                cb(Qcur, "Qcur", il);
1163911640
@@ -11642,7 +11643,7 @@ struct llm_build_context {
1164211643
1164311644                Kcur = ggml_rope_ext(
1164411645                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11645-                         n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11646+                         n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
1164611647                        ext_factor, attn_factor, beta_fast, beta_slow);
1164711648                cb(Kcur, "Kcur", il);
1164811649
@@ -11746,7 +11747,7 @@ struct llm_build_context {
1174611747
1174711748                Qcur = ggml_rope_ext(
1174811749                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
11749-                         n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11750+                         n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
1175011751                        ext_factor, attn_factor, beta_fast, beta_slow);
1175111752                cb(Qcur, "Qcur", il);
1175211753
@@ -11755,7 +11756,7 @@ struct llm_build_context {
1175511756
1175611757                Kcur = ggml_rope_ext(
1175711758                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11758-                         n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11759+                         n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
1175911760                        ext_factor, attn_factor, beta_fast, beta_slow);
1176011761                cb(Kcur, "Kcur", il);
1176111762
0 commit comments