Commit 056ab44

Properly encode/decode MoE layer step
1 parent 4a231eb

2 files changed: +5 -2 lines changed


convert_hf_to_gguf.py

Lines changed: 2 additions & 2 deletions
@@ -2827,10 +2827,10 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"])
         self.gguf_writer.add_expert_used_count(self.hparams["moe_k"])
-        self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_layer_interval"])
+        self.gguf_writer.add_interleave_moe_layer_step(self.hparams["moe_layer_interval"])

     def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
-        if "experts" in new_name:
+        if "exps" in new_name:
             return gguf.GGMLQuantizationType.F16
         return super().tensor_force_quant(name, new_name, bid, n_dims)
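The first hunk writes Ernie's `moe_layer_interval` hyperparameter through `add_interleave_moe_layer_step`, so the value lands under the same GGUF key that the loader below reads. The second hunk updates the quantization override to match the mapped GGUF tensor name rather than the original Hugging Face name: fused expert tensors in GGUF use names such as `ffn_down_exps`, which contain "exps" but not "experts". A minimal sketch of that check (the helper name and sample tensor names are illustrative, not converter code):

# Hypothetical stand-in for the override above, not part of convert_hf_to_gguf.py.
# Mapped GGUF expert tensors use an "exps" suffix (e.g. ffn_down_exps), while the
# original Hugging Face names contain "experts", so the mapped name is what the
# check must look at to force F16 for expert weights.
def force_f16(new_name: str) -> bool:
    return "exps" in new_name

print(force_f16("blk.0.ffn_down_exps.weight"))  # True  -> experts kept in F16
print(force_f16("blk.0.attn_q.weight"))         # False -> default quantization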

src/llama-model.cpp

Lines changed: 3 additions & 0 deletions
@@ -1610,6 +1610,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_ERNIE4_5_MOE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                if (arch == LLM_ARCH_ERNIE4_5_MOE) {
+                    ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
+                }
                 switch (hparams.n_layer) {
                     case 18: type = LLM_TYPE_0_3B; break;
                     default: type = LLM_TYPE_UNKNOWN;
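The interleave step is read only when the architecture is LLM_ARCH_ERNIE4_5_MOE, presumably because this case block is shared with the dense ERNIE4_5 variant, which has no such key. For context, here is a hedged sketch of what an interleave step usually means for a decoder stack: with step N, roughly every N-th block carries an MoE feed-forward while the others stay dense. The helper name and the exact modulo convention below are assumptions for illustration, not the actual use of n_moe_layer_step in llama.cpp:

# Illustrative only: one common reading of an interleaved MoE layer step.
# The (il + 1) % step convention is an assumption for this sketch and is not
# verified against llama-model.cpp.
def is_moe_layer(il: int, moe_layer_step: int) -> bool:
    return moe_layer_step > 0 and (il + 1) % moe_layer_step == 0

# With a step of 2 over 6 blocks, blocks 1, 3 and 5 (0-based) would use experts.
print([il for il in range(6) if is_moe_layer(il, 2)])  # [1, 3, 5]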
