136 changes: 136 additions & 0 deletions convert_hf_to_gguf.py
@@ -8957,6 +8957,142 @@
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))


@ModelBase.register("MegrezMoEForCausalLM")
class MegrezMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.MEGREZ_MOE

def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

tokpre = self.get_vocab_base_pre(tokenizer)
merges = []
vocab = {}
mergeable_ranks = getattr(tokenizer, "mergeable_ranks", {})
for token, rank in mergeable_ranks.items():
vocab[QwenModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
if len(merged) == 2:
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

vocab_size = self.hparams["vocab_size"]
assert tokenizer.vocab_size == vocab_size
special_tokens = getattr(tokenizer, "special_tokens", {})
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
tokens: list[str] = []
toktypes: list[int] = []
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token = reverse_vocab[i]
tokens.append(token)
if i in special_tokens.values():
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.NORMAL)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_token_merges(merges)

special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.add_to_gguf(self.gguf_writer)
# BOS token fix if needed
# self.gguf_writer.add_bos_token_id(<id>)
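
The merges loop above recovers GPT-2 style merge rules from a tiktoken-style rank table: splitting a token with max_rank equal to its own rank stops just short of re-assembling it, leaving its two immediate parents. A simplified standalone sketch of that idea (not the actual QwenModel.bpe implementation; the rank table is a toy example):

def bpe_split(ranks: dict[bytes, int], token: bytes, max_rank: int) -> list[bytes]:
    # Greedily apply the lowest-rank adjacent merge, but never one whose rank is
    # >= max_rank; with max_rank = rank(token) this yields the token's two parents.
    parts = [bytes([b]) for b in token]
    while True:
        best = None
        for i in range(len(parts) - 1):
            r = ranks.get(parts[i] + parts[i + 1])
            if r is not None and r < max_rank and (best is None or r < best[1]):
                best = (i, r)
        if best is None:
            return parts
        i = best[0]
        parts = parts[:i] + [parts[i] + parts[i + 1]] + parts[i + 2:]

toy_ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}  # hypothetical mergeable_ranks
toy_merges = []
for token, rank in toy_ranks.items():
    if len(token) == 1:
        continue
    parents = bpe_split(toy_ranks, token, max_rank=rank)
    if len(parents) == 2:
        toy_merges.append(b" ".join(parents))
print(toy_merges)  # [b'a b', b'ab c']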

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams

self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])

moe_intermediate_size = hparams["moe_intermediate_size"]
if moe_intermediate_size is not None and isinstance(moe_intermediate_size, (list, tuple)) and len(moe_intermediate_size) > 0:
assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
self.gguf_writer.add_expert_feed_forward_length(int(moe_intermediate_size[0]))

moe_topk = hparams["moe_topk"]
if moe_topk is not None and isinstance(moe_topk, (list, tuple)) and len(moe_topk) > 0:
assert all(topk == moe_topk[0] for topk in moe_topk)
self.gguf_writer.add_expert_used_count(int(moe_topk[0]))

moe_shared_expert = hparams["num_shared_expert"]
if moe_shared_expert is not None and isinstance(moe_shared_expert, (list, tuple)) and len(moe_shared_expert) > 0:
assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
self.gguf_writer.add_expert_shared_count(int(moe_shared_expert[0]))

rope_scaling = hparams.get("rope_scaling", {})
if rope_scaling.get("type") == "dynamic":
alpha = rope_scaling.get("alpha", 1000)
base = hparams.get("rope_theta", 10000.0)
hidden_size = hparams.get("hidden_size")
num_attention_heads = hparams.get("num_attention_heads")
max_position_embeddings = self.hparams.get("max_position_embeddings")
if None not in (hidden_size, num_attention_heads, max_position_embeddings):
dim = hidden_size // num_attention_heads

Check failure on line 9039 in convert_hf_to_gguf.py (GitHub Actions / pyright type-check): Operator "//" not supported for "None" (reportOptionalOperand)
scaled_base = base * (alpha ** (dim / (dim - 2)))
self.gguf_writer.add_rope_freq_base(scaled_base)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_rope_scaling_factor(1)
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)
self.gguf_writer.add_context_length(256 * 1024)
assert alpha == 1000 and base == 10000.0 and dim == 128 and max_position_embeddings in [32 * 1024, 256 * 1024], \
"Megrez dynamic RoPE scaling assumptions changed, please update the logic or context length manually"

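As a quick sanity check, a standalone sketch of the scaled base computed above with the asserted defaults (alpha = 1000, base = 10000.0, dim = 128); the numbers are illustrative, not additional conversion logic:

alpha = 1000
base = 10000.0
dim = 128  # hidden_size // num_attention_heads, per the assert above
scaled_base = base * (alpha ** (dim / (dim - 2)))
print(f"{scaled_base:.4e}")  # ~1.1159e+07, the value written as rope_freq_base
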
_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name == "lm_head.weight":
if self.hparams.get("tie_word_embeddings", False):
logger.info("Skipping tied output layer 'lm_head.weight'")
return []

if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))

return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
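
To make the merged layout explicit, a small sketch of the stacking done in modify_tensors, with toy sizes rather than the real Megrez dimensions; the resulting torch shape [n_expert, n_ff_exp, n_embd] is what the C++ loader later sees as ggml dims { n_embd, n_ff_exp, n_expert }, since ggml lists dimensions innermost-first:

import torch

n_expert, n_embd, n_ff_exp = 4, 8, 16  # toy sizes for illustration only
per_expert = [torch.randn(n_ff_exp, n_embd) for _ in range(n_expert)]  # e.g. gate_proj weights
merged = torch.stack(per_expert, dim=0)
print(merged.shape)  # torch.Size([4, 16, 8]) -> written as blk.N.ffn_gate_exps.weight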


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
@@ -427,6 +427,7 @@ class MODEL_ARCH(IntEnum):
COGVLM = auto()
MINIMAXM2 = auto()
PANGU_EMBED = auto()
MEGREZ_MOE = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -795,6 +796,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.MINIMAXM2: "minimax-m2",
MODEL_ARCH.COGVLM: "cogvlm",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MEGREZ_MOE: "megrez-moe",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -89,6 +89,7 @@ add_library(llama
models/mamba.cpp
models/minicpm3.cpp
models/minimax-m2.cpp
models/megrez-moe.cpp
models/mpt.cpp
models/nemotron-h.cpp
models/nemotron.cpp
26 changes: 26 additions & 0 deletions src/llama-arch.cpp
@@ -108,6 +108,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
{ LLM_ARCH_COGVLM, "cogvlm" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MEGREZ_MOE, "megrez-moe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -2378,6 +2379,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_MEGREZ_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_PANGU_EMBED,
{
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -112,6 +112,7 @@ enum llm_arch {
LLM_ARCH_MINIMAX_M2,
LLM_ARCH_COGVLM,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MEGREZ_MOE,
LLM_ARCH_UNKNOWN,
};

16 changes: 15 additions & 1 deletion src/llama-context.cpp
@@ -1382,7 +1382,21 @@ void llama_context::output_reorder() {
//

uint32_t llama_context::graph_max_nodes() const {
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
uint32_t base_nodes = std::max<uint32_t>(1024u, 8u*model.n_tensors());

// Megrez-MoE creates many intermediate tensors in build_mergez_moe_ffn for each layer:
// - sigmoid, add (bias), reshape (3x), get_rows, sum_rows, div, view_2d, mul_mat (per expert)
// - ggml_top_k internally calls ggml_argsort + ggml_view_4d (2 more tensors per layer)
// Each MoE layer needs ~30-40 intermediate tensors during graph construction
// With 30 MoE layers, this adds significant overhead to the graph (30 layers * 35 tensors = ~1050)
// During warmup, the graph is built 3 times with different batch sizes
if (model.arch == LLM_ARCH_MEGREZ_MOE) {
        // Reserve an extra 4096 nodes (about 4x the estimated ~1050) as a safety margin
        // for warmup's triple graph construction
base_nodes += 4096;
}

return base_nodes;
}

llm_graph_result * llama_context::get_gf_res_reserve() const {
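A rough worked example of the node budget above; n_tensors here is made up for illustration, not the actual Megrez-MoE tensor count:

n_tensors = 3000                        # hypothetical tensor count
base_nodes = max(1024, 8 * n_tensors)   # 24000
base_nodes += 4096                      # Megrez-MoE headroom for the ~1050 extra MoE nodes
print(base_nodes)                       # 28096
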
78 changes: 78 additions & 0 deletions src/llama-model.cpp
@@ -2180,12 +2180,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_PANGU_EMBED:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1
case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_MEGREZ_MOE:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);

switch (hparams.n_layer) {
case 31: type = LLM_TYPE_7B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}

@@ -3338,6 +3352,65 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
}
} break;
case LLM_ARCH_MEGREZ_MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

// Layer 0 is dense, layers 1-30 are MoE
if (i == 0) {
// Dense layer
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
} else {
// All MoE layers (1-30) have these
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, i), {n_expert}, 0);

if (n_expert == 0) {
throw std::runtime_error("n_expert must be > 0 for MEGREZ_MOE");
}
if (n_expert_used == 0) {
throw std::runtime_error("n_expert_used must be > 0 for MEGREZ_MOE");
}

// All MoE layers have shared expert
const int64_t n_ff_shexp = hparams.n_ff_shexp;
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);

// Only layers 1, 4, 7, 10, 13, 16, 19, 22, 25, 28 have actual expert tensors
// Pattern: (i-1) % 3 == 0 for i > 0
if ((i - 1) % 3 == 0) {
// MoE branch - use the expert-specific FF size from hparams
const int64_t n_ff_exp = hparams.n_ff_exp;

layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
}
// Note: layers that share experts (2, 3, 5, 6, etc.) only have gate_inp and shared expert
// They will reference the regular experts from their corresponding "full" layer during inference
}
}
} break;
case LLM_ARCH_QWEN3:
case LLM_ARCH_QWEN3VL:
{
@@ -7178,6 +7251,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_jais>(*this, params);
} break;
case LLM_ARCH_MEGREZ_MOE:
{
llm = std::make_unique<llm_build_megrez_moe>(*this, params);
} break;
case LLM_ARCH_NEMOTRON:
{
llm = std::make_unique<llm_build_nemotron>(*this, params);
@@ -7518,6 +7595,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GPTNEOX:
case LLM_ARCH_CODESHELL:
case LLM_ARCH_ORION:
case LLM_ARCH_MEGREZ_MOE:
case LLM_ARCH_NEMOTRON:
case LLM_ARCH_EXAONE:
case LLM_ARCH_EXAONE4:
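The layer layout encoded in load_tensors above (layer 0 dense; layers with (i - 1) % 3 == 0 owning real expert tensors; the remaining MoE layers reusing them and keeping only the router and shared expert) can be enumerated with a short sketch, assuming n_layer = 31 as in the 7B case:

n_layer = 31  # Megrez-MoE 7B, per the load_hparams switch above
expert_owners  = [i for i in range(1, n_layer) if (i - 1) % 3 == 0]
expert_sharers = [i for i in range(1, n_layer) if (i - 1) % 3 != 0]
print(expert_owners)        # [1, 4, 7, 10, 13, 16, 19, 22, 25, 28] -> layers with ffn_{gate,down,up}_exps
print(len(expert_sharers))  # 20 -> layers that reference an owner's experts at inference time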