
Commit 992d4f0

Rope fixes.
1 parent bd27e81 commit 992d4f0

3 files changed, +8 -4 lines changed


convert_hf_to_gguf.py

Lines changed: 1 addition & 0 deletions
@@ -2828,6 +2828,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"])
         self.gguf_writer.add_expert_used_count(self.hparams["moe_k"])
         self.gguf_writer.add_interleave_moe_layer_step(self.hparams["moe_layer_interval"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
         if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
         if (shared_expert_intermediate_size := self.hparams.get('intermediate_size')) is not None and (num_key_value_heads := self.hparams.get('num_key_value_heads')) is not None:
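
Note: rope_theta is the RoPE frequency base; without this metadata the converted GGUF would leave the runtime to fall back on a default base. A minimal, self-contained sketch (illustrative only; the freq_base and head_dim values are assumptions, not ERNIE 4.5 hyperparameters) of how the base determines the per-dimension rotation frequencies used by RoPE:

#include <cmath>
#include <cstdio>

// Illustrative sketch only: how a RoPE frequency base ("rope_theta")
// translates into per-dimension rotation frequencies.
int main() {
    const float freq_base = 500000.0f; // assumed example rope_theta
    const int   head_dim  = 128;       // assumed dimensions per attention head

    for (int i = 0; i < head_dim / 2; i += 16) {
        // standard RoPE: inv_freq_i = freq_base^(-2i / head_dim)
        const float inv_freq = std::pow(freq_base, -2.0f * i / head_dim);
        // a token at position p is rotated by p * inv_freq on dimension pair i
        std::printf("pair %3d: %.3e rad per token\n", i, inv_freq);
    }
    return 0;
}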

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
@@ -1816,6 +1816,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+           { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
        },
    },
    {
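
Note: this adds the name mapping for the per-layer expert routing bias tensor (exp_probs_b) to this architecture's tensor-name table. As a quick illustration (not the actual llama.cpp lookup code), the "%d" in the stored pattern is filled with the block index when per-layer tensor names are resolved:

#include <cstdio>

// Illustrative sketch only: tensor-name patterns like "blk.%d.exp_probs_b"
// are expanded with the layer (block) index to address per-layer tensors.
int main() {
    char name[64];
    for (int il = 0; il < 3; ++il) {
        std::snprintf(name, sizeof(name), "blk.%d.exp_probs_b", il);
        std::printf("%s\n", name); // blk.0.exp_probs_b, blk.1.exp_probs_b, ...
    }
    return 0;
}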

src/llama-model.cpp

Lines changed: 6 additions & 4 deletions
@@ -8367,7 +8367,6 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {

        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;
-
            // norm
            {
                cur = build_norm(inpL,
@@ -8404,15 +8403,17 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

+                const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+                const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
                Qcur = ggml_rope_ext(
                        ctx0, Qcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                        ext_factor, attn_factor, beta_fast, beta_slow
                        );

                Kcur = ggml_rope_ext(
                        ctx0, Kcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                        ext_factor, attn_factor, beta_fast, beta_slow
                        );

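Note: the rope calls now take freq_base_l / freq_scale_l from per-layer getters instead of the global freq_base / freq_scale, so layers whose trained RoPE settings differ from the global values are rotated correctly. A hedged sketch of what these two parameters control (assumed example values; the real computation is inside ggml_rope_ext):

#include <cmath>
#include <cstdio>

// Illustrative sketch only: the effect of freq_base and freq_scale on the
// RoPE rotation angle for one dimension pair. Values are assumed examples.
static float rope_angle(int pos, int pair, int head_dim, float freq_base, float freq_scale) {
    const float inv_freq = std::pow(freq_base, -2.0f * pair / head_dim);
    // freq_scale < 1.0 compresses the angle, i.e. linear RoPE context scaling
    return pos * freq_scale * inv_freq;
}

int main() {
    const int pos = 1024, pair = 8, head_dim = 128;
    // two hypothetical layers with different per-layer settings
    std::printf("layer A: %.4f rad\n", rope_angle(pos, pair, head_dim, 10000.0f,  1.0f));
    std::printf("layer B: %.4f rad\n", rope_angle(pos, pair, head_dim, 500000.0f, 0.25f));
    return 0;
}
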
@@ -8435,7 +8436,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
            cb(ffn_inp, "ffn_inp", il);

            // feed-forward network
-            bool is_moe_layer = arch == LLM_ARCH_ERNIE4_5_MOE && hparams.n_moe_layer_step > 0 && (il + 1) % hparams.n_moe_layer_step == 0;
+            bool is_moe_layer = arch == LLM_ARCH_ERNIE4_5_MOE && hparams.n_moe_layer_step > 0 && il >= hparams.n_moe_layer_step;

            if (!is_moe_layer) {
                cur = build_norm(ffn_inp,
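
Note: the predicate change alters which layers are MoE layers: previously every n_moe_layer_step-th layer (1-based), now every layer from index n_moe_layer_step onward. A small worked comparison of the two conditions (hypothetical n_layer / n_moe_layer_step values; the architecture check is omitted):

#include <cstdio>

// Illustrative sketch only: compare the old and new MoE-layer predicates
// from this hunk for hypothetical n_layer / n_moe_layer_step values.
int main() {
    const int n_layer = 8, n_moe_layer_step = 1; // assumed example values

    for (int il = 0; il < n_layer; ++il) {
        const bool old_moe = n_moe_layer_step > 0 && (il + 1) % n_moe_layer_step == 0;
        const bool new_moe = n_moe_layer_step > 0 && il >= n_moe_layer_step;
        std::printf("layer %d: old=%s new=%s\n", il,
                    old_moe ? "moe" : "dense",
                    new_moe ? "moe" : "dense");
    }
    // with step 1, the old rule marks every layer as MoE, while the new rule
    // keeps layer 0 dense and routes through experts from layer 1 onward
    return 0;
}
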
@@ -16828,6 +16829,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_SMOLLM3:
        case LLM_ARCH_ARCEE:
        case LLM_ARCH_ERNIE4_5:
+        case LLM_ARCH_ERNIE4_5_MOE:
            return LLAMA_ROPE_TYPE_NORM;

        // the pairs of head values are offset by n_rot/2
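
Note: adding LLM_ARCH_ERNIE4_5_MOE to this case list selects the "normal" RoPE layout, which rotates adjacent dimension pairs, as opposed to the NEOX layout referenced by the comment above (pairs offset by n_rot/2). A small illustrative sketch of the difference in pairing:

#include <cstdio>

// Illustrative sketch only: which dimension indices form a rotated pair
// under the two common RoPE layouts, for a small assumed n_rot.
int main() {
    const int n_rot = 8;

    std::printf("NORM (adjacent pairs):\n");
    for (int i = 0; i < n_rot; i += 2) {
        std::printf("  (%d, %d)\n", i, i + 1);
    }

    std::printf("NEOX (pairs offset by n_rot/2):\n");
    for (int i = 0; i < n_rot / 2; ++i) {
        std::printf("  (%d, %d)\n", i, i + n_rot / 2);
    }
    return 0;
}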
