diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 222f6ed6d..cb678db81 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1048,6 +1048,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756": # ref: https://huggingface.co/JetBrains/Mellum-4b-base res = "mellum" + if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df": + # ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer + res = "afmoe" if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206": # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0 res = "bailingmoe2" @@ -2457,6 +2460,100 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) +@ModelBase.register("AfmoeForCausalLM") +class AfmoeModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.AFMOE + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # MoE parameters + if (n_experts := self.hparams.get("num_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None: + self.gguf_writer.add_expert_shared_count(n_shared_experts) + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None: + self.gguf_writer.add_leading_dense_block_count(n_dense_layers) + + # Gating function (sigmoid) + if (score_func := self.hparams.get("score_func")) is not None and score_func == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + # Route normalization and scaling + if (route_norm := self.hparams.get("route_norm")) is not None: + self.gguf_writer.add_expert_weights_norm(route_norm) + if (route_scale := self.hparams.get("route_scale")) is not None: + self.gguf_writer.add_expert_weights_scale(route_scale) + + # Sliding window attention + if (sliding_window := self.hparams.get("sliding_window")) is not None: + self.gguf_writer.add_sliding_window(sliding_window) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Handle expert weights - they're already merged in the HF format + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["gate_proj", "up_proj", "down_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename_to_retrieve]) + del self._experts[bid][ename_to_retrieve] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + return tensors + else: + return [] + + # Map attention gate + elif ".self_attn.gate_proj." 
in name and bid is not None: + return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid), data_torch)] + + # Map shared experts + elif ".mlp.shared_experts.gate_proj." in name and bid is not None: + return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), data_torch)] + elif ".mlp.shared_experts.up_proj." in name and bid is not None: + return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), data_torch)] + elif ".mlp.shared_experts.down_proj." in name and bid is not None: + return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, bid), data_torch)] + + # Pre FFN norm + elif ".pre_mlp_layernorm." in name and bid is not None: + return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_PRE_NORM, bid), data_torch)] + + # Post FFN norm + elif ".post_mlp_layernorm." in name and bid is not None: + return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_POST_NORM, bid), data_torch)] + + # Map router + elif ".mlp.router.gate." in name and bid is not None: + return [(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_INP, bid), data_torch)] + + if name.endswith(".expert_bias"): + name = name.replace(".expert_bias", ".expert_bias.bias") + + return [(self.map_tensor_name(name), data_torch)] + + @ModelBase.register( "LlavaForConditionalGeneration", # pixtral "Mistral3ForConditionalGeneration", # mistral small 3.1 diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 7df96eb08..b8f694e86 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"}, {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", }, {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, + {"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", }, {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", }, {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", }, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6b4b6c5ab..1cd0efad4 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -409,6 +409,7 @@ class MODEL_ARCH(IntEnum): BAILINGMOE2 = auto() DOTS1 = auto() ARCEE = auto() + AFMOE = auto() ERNIE4_5 = auto() ERNIE4_5_MOE = auto() HUNYUAN_MOE = auto() @@ -464,6 +465,7 @@ class MODEL_TENSOR(IntEnum): ATTN_POST_NORM = auto() ATTN_ROT_EMBD = auto() ATTN_SINKS = auto() + ATTN_GATE = auto() FFN_GATE_INP = auto() FFN_GATE_INP_SHEXP = auto() FFN_NORM = auto() @@ -776,6 +778,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BAILINGMOE2: "bailingmoe2", MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.ARCEE: "arcee", + MODEL_ARCH.AFMOE: "afmoe", MODEL_ARCH.ERNIE4_5: "ernie4_5", MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe", MODEL_ARCH.FALCON_H1: "falcon-h1", @@ -828,6 +831,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks", + MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", MODEL_TENSOR.ATTN_K_NORM: 
"blk.{bid}.attn_k_norm", MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", @@ -2693,6 +2697,33 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.AFMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.ERNIE4_5: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 929406687..0dd2b6a12 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -340,11 +340,12 @@ class TensorNameMap: "model.layers.{bid}.feedforward_layernorm", # apertus ), - # Post feed-forward norm + # Pre feed-forward norm MODEL_TENSOR.FFN_PRE_NORM: ( "model.layers.{bid}.pre_feedforward_layernorm", # gemma2 "layers.{bid}.pre_feedforward_layernorm", # embeddinggemma "model.layers.{bid}.pre_ff_layernorm.weight", + "model.layers.{bid}.pre_mlp_layernorm", # afmoe ), # Post feed-forward norm @@ -380,6 +381,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1 "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe "model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2 + "model.layers.{bid}.mlp.expert_bias", # afmoe "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2 ), diff --git a/models/ggml-vocab-afmoe.gguf b/models/ggml-vocab-afmoe.gguf new file mode 100644 index 000000000..11271bfae Binary files /dev/null and b/models/ggml-vocab-afmoe.gguf differ diff --git a/models/ggml-vocab-afmoe.gguf.inp b/models/ggml-vocab-afmoe.gguf.inp new file mode 100644 index 000000000..86b934e40 --- /dev/null +++ b/models/ggml-vocab-afmoe.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Äpfel +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! 
+__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-afmoe.gguf.out b/models/ggml-vocab-afmoe.gguf.out new file mode 100644 index 000000000..694d32628 --- /dev/null +++ b/models/ggml-vocab-afmoe.gguf.out @@ -0,0 +1,46 @@ + 1129 252 51 252 20861 3621 + 49116 25524 343 + + 252 + 288 + 344 + 229 + 230 + 327 + 1866 + 4402 + 14795 1117 + 30197 1117 + 14795 3295 + 30197 3295 + 30197 3295 32 + 14795 43 1117 32 + 30197 43 1117 32 + 483 351 69865 279 45 11112 + 118 18799 252 54 115 4546 30869 25372 4191 13934 + 23835 183893 7432 30515 125974 185839 20324 + 124940 92255 273 160060 191869 44968 256 188211 21207 147 142156 195704 142156 21207 127 92255 259 21207 255 190792 21207 259 195704 21207 263 + 12479 387 10171 40 22860 146 18932 15540 136 10094 387 49707 77415 91293 40 70574 387 9266 56494 384 651 692 1204 9776 40 + 14795 + 30197 + 252 30197 + 288 30197 + 344 30197 + 344 30197 230 344 30197 + 387 + 230 399 + 38 6260 + 14795 43 366 76896 32 822 429 383 22860 255 2972 111778 3712 27304 19409 48 23988 18044 13814 73996 + 183574 + 50 + 2158 + 11805 + 50 11805 + 2158 11805 + 11805 11805 + 50 11805 11805 + 2158 11805 11805 + 11805 11805 11805 + 66 70789 96 140747 + 104867 + 144635 20623 120822 22300 4402 71947 2759 24373 12479 387 10171 40 22860 146 18932 15540 136 10094 387 49707 77415 91293 40 70574 69865 279 63816 279 252 50 252 2158 252 11805 252 50 11805 252 2158 11805 252 11805 11805 252 50 11805 11805 252 2158 11805 11805 252 50 45 50 252 50 634 50 252 50 1472 50 252 124940 92255 273 160060 191869 44968 256 188211 21207 147 142156 195704 142156 21207 127 92255 259 45614 255 2972 111778 3712 27304 19409 48 23988 18044 13814 73996 79520 1235 23427 13373 183893 7432 30515 125974 185839 20324 27123 36632 25121 3124 36057 36678 183574 31148 10446 365 1908 874 578 63490 438 414 765 43 578 1954 383 2259 62 578 76 487 2259 365 2130 960 394 43 578 67 383 679 766 8748 62 1155 38 35185 290 66450 75 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 630b2cddf..651834a00 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -35,6 +35,7 @@ add_library(llama unicode-data.cpp unicode.cpp unicode.h + models/afmoe.cpp models/apertus.cpp models/arcee.cpp models/arctic.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568..b2eb2477f 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -90,6 +90,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_BAILINGMOE2, "bailingmoe2" }, { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_ARCEE, "arcee" }, + { LLM_ARCH_AFMOE, "afmoe" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, @@ -333,6 +334,36 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_AFMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { 
LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, { LLM_ARCH_LLAMA4, { @@ -2444,6 +2475,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index a769dd1e8..ae7fa222a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -94,6 +94,7 @@ enum llm_arch { LLM_ARCH_BAILINGMOE2, LLM_ARCH_DOTS1, LLM_ARCH_ARCEE, + LLM_ARCH_AFMOE, LLM_ARCH_ERNIE4_5, LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, @@ -312,6 +313,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_ATTN_ROT_EMBD, LLM_TENSOR_ATTN_SINKS, + LLM_TENSOR_ATTN_GATE, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1987135ca..124055046 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -695,6 +695,38 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_AFMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + // Set up interleaved sliding window attention (ISWA) + // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4) + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(4); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + // Default to sigmoid 
if not set + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + switch (hparams.n_layer) { + case 56: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_DECI: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5763,6 +5795,71 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_AFMOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // dual attention normalization + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + // attention projections + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // Q/K normalization + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + // attention gating + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + + // dual ffn normalization + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + if (static_cast(i) >= hparams.n_layer_dense_lead) { + // MoE layers + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + + // grouped expert weights + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + } + } else { + // Dense layers + layer.ffn_gate = 
create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: { @@ -7256,6 +7353,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_AFMOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_ERNIE4_5: { llm = std::make_unique(*this, params); @@ -7478,7 +7579,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_BAILINGMOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_SMOLLM3: - case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: return LLAMA_ROPE_TYPE_NORM; @@ -7537,6 +7637,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MINIMAX_M2: case LLM_ARCH_COGVLM: case LLM_ARCH_PANGU_EMBED: + case LLM_ARCH_ARCEE: + case LLM_ARCH_AFMOE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index 71ff148e0..4e812bccb 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -234,6 +234,7 @@ struct llama_layer { struct ggml_tensor * wk_enc = nullptr; struct ggml_tensor * wv_enc = nullptr; struct ggml_tensor * wo_enc = nullptr; + struct ggml_tensor * wqkv_gate = nullptr; // attention bias struct ggml_tensor * bq = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 735c5d547..c8262e292 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -443,6 +443,16 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_AFMOE: + regex_exprs = { + // Digits in groups of 1-3 + "\\p{N}{1,3}", + // CJK and Asian scripts (using direct Unicode literals) + "[一-鿿㐀-䶿豈-﫿぀-ゟ゠-ヿ・-゚⼀-⿟เ-๿຀-໿ក-៿က-႟ꩠ-ꩿꧠ-꧿가-힯ᄀ-ᇿ]+", + // Main BPE pattern + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1993,6 +2003,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "grok-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2; clean_spaces = false; + } else if ( + tokenizer_pre == "afmoe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_AFMOE; + clean_spaces = false; } else if ( tokenizer_pre == "minimax-m2") { pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 1194ec473..55f8f3923 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -50,6 +50,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, + LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, }; struct LLM_KV; diff --git a/src/models/afmoe.cpp b/src/models/afmoe.cpp new file mode 100644 index 000000000..a80aef20a --- /dev/null +++ b/src/models/afmoe.cpp @@ -0,0 +1,187 @@ +#include "models.h" + +llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + 
ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // MuP scaling: embeddings * sqrt(hidden_size) + // mup_enabled = true, hidden_size = 1024, scale = 32.0 + inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd))); + cb(inpL, "inp_embd_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // dual attention normalization (pre) + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * attn_inp = cur; // save input for gate computation + + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // compute gate from input + ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp); + cb(gate, "attn_gate_proj", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + // Q/K normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + + // RoPE only for sliding_attention layers (every 4th layer is full_attention) + // layer_types[i] = "sliding_attention" if (i+1) % global_attn_every_n_layers != 0 + bool is_sliding = ((il + 1) % 4) != 0; // global_attn_every_n_layers = 4 + if (is_sliding) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_rope", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur, "Kcur_rope", il); + } + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cur = build_attn(inp_attn, + NULL, NULL, // wo will be applied after gating + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + + // attention gating: attn_out * sigmoid(gate) BEFORE o_proj + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "attn_gate_sig", il); + cur = ggml_mul(ctx0, cur, gate); + cb(cur, "attn_gated", il); + + // now apply output projection + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_o_proj", il); + } + + // dual attention normalization (post) + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // dual ffn normalization (pre) + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // MoE or dense FFN + if ((uint32_t)il >= hparams.n_layer_dense_lead) { + // MoE layer with sigmoid routing, normalization, and scaling + ggml_tensor * moe_out = 
build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, + hparams.expert_weights_norm != 0, // norm_w (route_norm=True) + hparams.expert_weights_scale != 0.0f, // scale_w + hparams.expert_weights_scale, // w_scale (route_scale=2.826) + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // shared expert + if (hparams.n_expert_shared > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } else { + // dense layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + // dual ffn normalization (post) + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} \ No newline at end of file diff --git a/src/models/models.h b/src/models/models.h index 2fffb382d..4d7aeb4f4 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -57,6 +57,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { int il) const; }; +struct llm_build_afmoe : public llm_graph_context { + llm_build_afmoe(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_apertus : public llm_graph_context { llm_build_apertus(const llama_model & model, const llm_graph_params & params); };
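
For reference, a minimal PyTorch sketch (not part of the patch) of the two mechanisms this architecture adds on top of a standard llama-style block: the sigmoid attention gate applied before the output projection, and the sigmoid-routed MoE whose expert bias is used for top-k selection only. It assumes torch >= 2.0; all function and parameter names here (gated_attention, moe_ffn, w_gate, expert_bias, route_scale, ...) are illustrative and do not come from the checkpoint, the converter, or ggml. RoPE, the interleaved sliding-window mask, KV caching and the shared expert are omitted; see src/models/afmoe.cpp above for the actual graph.

# --- illustrative reference sketch only, not part of the patch ---
import torch
import torch.nn.functional as F

def rms_norm(x, weight, eps=1e-6):
    # plain RMSNorm, as used for attn_q_norm / attn_k_norm and the block norms
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def gated_attention(x, wq, wk, wv, wo, w_gate, q_norm, k_norm, n_head, n_head_kv):
    # x: [n_tokens, n_embd]; weight matrices are [out_features, in_features]
    n_tokens = x.shape[0]
    head_dim = wq.shape[0] // n_head

    q = (x @ wq.T).view(n_tokens, n_head, head_dim)
    k = (x @ wk.T).view(n_tokens, n_head_kv, head_dim)
    v = (x @ wv.T).view(n_tokens, n_head_kv, head_dim)
    gate = x @ w_gate.T                      # same input as Q/K/V, like wqkv_gate

    q = rms_norm(q, q_norm)                  # per-head Q/K norm (RoPE omitted here)
    k = rms_norm(k, k_norm)

    rep = n_head // n_head_kv                # expand KV heads for GQA
    k = k.repeat_interleave(rep, dim=1)
    v = v.repeat_interleave(rep, dim=1)
    attn = F.scaled_dot_product_attention(
        q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True,
    ).transpose(0, 1).reshape(n_tokens, -1)

    # gate the attention output BEFORE the output projection, then project
    return (attn * torch.sigmoid(gate)) @ wo.T

def moe_ffn(x, router_w, expert_bias, w_up, w_gate, w_down, n_used, route_scale):
    # sigmoid router; the expert bias shifts top-k selection but not the mixing weights
    scores = torch.sigmoid(x @ router_w.T)                    # [n_tokens, n_expert]
    topk = torch.topk(scores + expert_bias, n_used, dim=-1).indices
    weights = torch.gather(scores, -1, topk)
    weights = weights / weights.sum(-1, keepdim=True)         # route_norm
    weights = weights * route_scale                           # route_scale

    out = torch.zeros_like(x)
    for t in range(x.shape[0]):                               # naive per-token loop
        for slot in range(n_used):
            e = topk[t, slot]
            h = F.silu(x[t] @ w_gate[e].T) * (x[t] @ w_up[e].T)
            out[t] = out[t] + weights[t, slot] * (h @ w_down[e].T)
    # in the real MoE layers a dense SwiGLU shared-expert output is added to this,
    # and both sub-blocks sit between pre/post RMS norms with residual connections,
    # as built in llm_build_afmoe above
    return out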