
Commit 8501cb3

Add Ernie4.5 MoE
1 parent 7de5c7c commit 8501cb3

6 files changed: +343 −26 lines

convert_hf_to_gguf.py

Lines changed: 82 additions & 1 deletion
@@ -2781,7 +2781,7 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         num_heads = self.hparams["num_attention_heads"]
         num_kv_heads = self.hparams["num_key_value_heads"]
-        head_dim = self.hparams["head_dim"]
+        head_dim = self.hparams["hidden_size"] // num_heads
 
         if "ernie." in name:
             name = name.replace("ernie.", "model.")
@@ -2814,6 +2814,87 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Ernie4_5_MoeForCausalLM")
+class Ernie4_5MoeModel(Ernie4_5Model):
+    model_arch = gguf.MODEL_ARCH.ERNIE4_5_MOE
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._experts = [{} for _ in range(self.block_count)]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["moe_k"])
+        self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_layer_interval"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Modify correction bias name as in DeepseekV2
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # skip Multi-Token Prediction (MTP) layers (again, same as DeepseekV2)
+        match = re.match(r"model.mtp_block.(\d+)", name)
+        if match:
+            return []
+
+        # skip all other MTP tensors for now
+        match = re.match(r"model.mtp_emb_norm.(\d+)", name)
+        if match:
+            return []
+
+        match = re.match(r"model.mtp_hidden_norm.(\d+)", name)
+        if match:
+            return []
+
+        match = re.match(r"model.mtp_linear_proj.(\d+)", name)
+        if match:
+            return []
+
+        # process the experts separately
+        if name.find("experts.") != -1 and name.find("shared") == -1:
+            n_experts = self.hparams["moe_num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["gate_proj", "up_proj", "down_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename_to_retrieve])
+                        del self._experts[bid][ename_to_retrieve]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    merged_name = f"layers.{bid}.mlp.experts.{w_name}.weight"
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+
+                return tensors
+            else:
+                return []
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                logger.warning(f"Unprocessed experts: {experts}")
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register(
     "Qwen2VLModel",
     "Qwen2VLForConditionalGeneration",

gguf-py/gguf/constants.py

Lines changed: 24 additions & 0 deletions
@@ -363,6 +363,7 @@ class MODEL_ARCH(IntEnum):
     DOTS1        = auto()
     ARCEE        = auto()
     ERNIE4_5     = auto()
+    ERNIE4_5_MOE = auto()
     HUNYUAN_MOE  = auto()
     SMOLLM3      = auto()
     LFM2         = auto()
@@ -677,6 +678,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DOTS1:        "dots1",
     MODEL_ARCH.ARCEE:        "arcee",
     MODEL_ARCH.ERNIE4_5:     "ernie4_5",
+    MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5_moe",
     MODEL_ARCH.FALCON_H1:    "falcon-h1",
     MODEL_ARCH.HUNYUAN_MOE:  "hunyuan-moe",
     MODEL_ARCH.SMOLLM3:      "smollm3",
@@ -1973,6 +1975,28 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP_SHEXP,
         MODEL_TENSOR.FFN_EXP_PROBS_B,
     ],
+    MODEL_ARCH.ERNIE4_5_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.PLM: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT,

gguf-py/gguf/tensor_mapping.py

Lines changed: 27 additions & 22 deletions
@@ -311,14 +311,16 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.router",             # llama4 jamba
             "encoder.layers.{bid}.mlp.router.layer",              # nomic-bert-moe
             "model.layers.{bid}.mlp.gate.wg",                     # hunyuan
+            "model.layers.{bid}.mlp.ffn_gate_inp.weight",         # ernie4.5-moe
         ),
 
         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
         ),
 
         MODEL_TENSOR.FFN_EXP_PROBS_B: (
-            "model.layers.{bid}.mlp.gate.e_score_correction",        # deepseek-v3 dots1
+            "model.layers.{bid}.mlp.gate.e_score_correction",        # deepseek-v3 dots1
+            "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
         ),
 
         # Feed-forward up
@@ -357,13 +359,14 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
-            "layers.{bid}.feed_forward.experts.w3",            # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear_v",    # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.v1",     # dbrx
-            "model.layers.{bid}.mlp.experts.up_proj",          # qwen2moe olmoe (merged)
-            "model.layers.{bid}.block_sparse_moe.experts.w3",  # phimoe (merged)
-            "model.layers.{bid}.feed_forward.experts.up_proj", # llama4
-            "encoder.layers.{bid}.mlp.experts.mlp.w1",         # nomic-bert-moe
+            "layers.{bid}.feed_forward.experts.w3",            # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_v",    # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.v1",     # dbrx
+            "model.layers.{bid}.mlp.experts.up_proj",          # qwen2moe olmoe (merged)
+            "model.layers.{bid}.block_sparse_moe.experts.w3",  # phimoe (merged)
+            "model.layers.{bid}.feed_forward.experts.up_proj", # llama4
+            "encoder.layers.{bid}.mlp.experts.mlp.w1",         # nomic-bert-moe
+            "layers.{bid}.mlp.experts.up_proj.weight",         # ernie4.5-moe
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -396,12 +399,13 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
-            "layers.{bid}.feed_forward.experts.w1",              # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear",        # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.w1",       # dbrx
-            "model.layers.{bid}.mlp.experts.gate_proj",          # qwen2moe olmoe (merged)
-            "model.layers.{bid}.block_sparse_moe.experts.w1",    # phimoe (merged)
-            "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
+            "layers.{bid}.feed_forward.experts.w1",              # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear",        # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w1",       # dbrx
+            "model.layers.{bid}.mlp.experts.gate_proj",          # qwen2moe olmoe (merged)
+            "model.layers.{bid}.block_sparse_moe.experts.w1",    # phimoe (merged)
+            "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
+            "layers.{bid}.mlp.experts.gate_proj.weight",         # ernie4.5-moe
         ),
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -443,14 +447,15 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
-            "layers.{bid}.feed_forward.experts.w2",              # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear_1",      # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.w2",       # dbrx
-            "model.layers.{bid}.mlp.experts.down_proj",          # qwen2moe olmoe (merged)
-            "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
-            "model.layers.{bid}.block_sparse_moe.experts.w2",    # phimoe (merged)
-            "model.layers.{bid}.feed_forward.experts.down_proj", # llama4
-            "encoder.layers.{bid}.mlp.experts.mlp.w2",           # nomic-bert-moe
+            "layers.{bid}.feed_forward.experts.w2",              # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_1",      # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w2",       # dbrx
+            "model.layers.{bid}.mlp.experts.down_proj",          # qwen2moe olmoe (merged)
+            "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
+            "model.layers.{bid}.block_sparse_moe.experts.w2",    # phimoe (merged)
+            "model.layers.{bid}.feed_forward.experts.down_proj", # llama4
+            "encoder.layers.{bid}.mlp.experts.mlp.w2",           # nomic-bert-moe
+            "layers.{bid}.mlp.experts.down_proj.weight",         # ernie4.5-moe
         ),
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
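
The new ernie4.5-moe entries above map the merged expert names emitted by the converter (e.g. layers.{bid}.mlp.experts.up_proj.weight) onto the corresponding GGUF expert tensors. Below is a much-simplified, illustrative stand-in for that lookup, not the actual TensorNameMap implementation, showing how a {bid} template can resolve to a per-block GGUF name; the blk.%d.ffn_*_exps targets come from the llama-arch.cpp table further down:

import re

# Illustrative only: just the three merged-expert templates added for ernie4.5-moe,
# mapped to the GGUF expert tensor names they end up as.
ERNIE_MOE_EXPERT_MAP = {
    "layers.{bid}.mlp.experts.gate_proj.weight": "blk.{bid}.ffn_gate_exps.weight",
    "layers.{bid}.mlp.experts.up_proj.weight":   "blk.{bid}.ffn_up_exps.weight",
    "layers.{bid}.mlp.experts.down_proj.weight": "blk.{bid}.ffn_down_exps.weight",
}

def map_merged_expert_name(name: str) -> str | None:
    """Resolve a merged expert tensor name to its GGUF name, or None if it does not match."""
    for template, gguf_template in ERNIE_MOE_EXPERT_MAP.items():
        # Turn the {bid} placeholder into a capture group for the block index.
        pattern = "^" + re.escape(template).replace(r"\{bid\}", r"(\d+)") + "$"
        m = re.match(pattern, name)
        if m:
            return gguf_template.format(bid=m.group(1))
    return None

print(map_merged_expert_name("layers.3.mlp.experts.up_proj.weight"))  # blk.3.ffn_up_exps.weight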

src/llama-arch.cpp

Lines changed: 25 additions & 0 deletions
@@ -81,6 +81,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DOTS1,            "dots1"        },
     { LLM_ARCH_ARCEE,            "arcee"        },
     { LLM_ARCH_ERNIE4_5,         "ernie4_5"     },
+    { LLM_ARCH_ERNIE4_5_MOE,     "ernie4_5_moe" },
     { LLM_ARCH_HUNYUAN_MOE,      "hunyuan-moe"  },
     { LLM_ARCH_SMOLLM3,          "smollm3"      },
     { LLM_ARCH_LFM2,             "lfm2"         },
@@ -1793,6 +1794,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_ERNIE4_5_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,  "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,  "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_HUNYUAN_MOE,
         {
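
The strings registered above are printf-style templates: the non-block tensors are literal names, while the blk.%d.* entries get the block index substituted in when llama.cpp builds or looks up tensor names. A small sketch of that expansion for the new ernie4_5_moe per-block entries (kept in Python for consistency with the other examples here; it is not llama.cpp code):

# Per-block tensor name templates, copied from the LLM_ARCH_ERNIE4_5_MOE table above.
ERNIE4_5_MOE_BLOCK_TENSORS = [
    "blk.%d.attn_norm",     "blk.%d.attn_q",         "blk.%d.attn_k",         "blk.%d.attn_v",
    "blk.%d.attn_output",   "blk.%d.ffn_norm",       "blk.%d.ffn_gate",       "blk.%d.ffn_down",
    "blk.%d.ffn_up",        "blk.%d.ffn_gate_inp",   "blk.%d.ffn_gate_shexp", "blk.%d.ffn_down_shexp",
    "blk.%d.ffn_up_shexp",  "blk.%d.ffn_gate_exps",  "blk.%d.ffn_down_exps",  "blk.%d.ffn_up_exps",
]

def block_tensor_names(bid: int) -> list[str]:
    """Expand every %d template for one block, mirroring how llama.cpp formats tensor names."""
    return [template % bid for template in ERNIE4_5_MOE_BLOCK_TENSORS]

# e.g. the attention tensors of block 7
print(block_tensor_names(7)[:4])  # ['blk.7.attn_norm', 'blk.7.attn_q', 'blk.7.attn_k', 'blk.7.attn_v']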

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ enum llm_arch {
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
+    LLM_ARCH_ERNIE4_5_MOE,
     LLM_ARCH_HUNYUAN_MOE,
     LLM_ARCH_SMOLLM3,
     LLM_ARCH_LFM2,
