
Commit 06ed421
Model: Minimax M2

1 parent: 851553e
14 files changed: +470 additions, -1 deletion

convert_hf_to_gguf.py

Lines changed: 82 additions & 0 deletions

@@ -1054,6 +1054,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
             # ref: https://huggingface.co/ibm-granite/granite-docling-258M
             res = "granite-docling"
+        if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
+            # ref: https://huggingface.co/MiniMaxAI/MiniMax-M2
+            res = "minimax-m2"

         if res is None:
             logger.warning("\n")

@@ -6909,6 +6912,85 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")


+@ModelBase.register("MiniMaxM2ForCausalLM")
+class MiniMaxM2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.MINIMAXM2
+    _experts_cache: dict[int, dict[str, Tensor]] = {}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["num_experts"] = self.hparams["num_local_experts"]
+
+    def set_gguf_parameters(self):
+        if self.hparams["scoring_func"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif self.hparams["scoring_func"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
+
+        block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
+        n_embd = self.find_hparam(["hidden_size", "n_embd"])
+        n_head = self.find_hparam(["num_attention_heads", "n_head"])
+        n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
+        rms_eps = self.find_hparam(["rms_norm_eps"])
+        max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
+        head_dim = self.find_hparam(["head_dim"])
+
+        self.gguf_writer.add_context_length(max_pos_embds)
+        self.gguf_writer.add_embedding_length(n_embd)
+        self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
+        self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
+        self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts"]))
+        self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok"]))
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(n_head)
+        self.gguf_writer.add_head_count_kv(n_head_kv)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
+        self.gguf_writer.add_layer_norm_eps(rms_eps)
+        self.gguf_writer.add_key_length(head_dim)
+        self.gguf_writer.add_value_length(head_dim)
+        self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
+        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # merge expert weights
+        if 'experts' in name:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            expert_cache = self._experts_cache.setdefault(bid, {})
+            expert_cache[name] = data_torch
+            expert_weights = ["w1", "w2", "w3"]
+
+            # not enough expert weights to merge
+            if len(expert_cache) < n_experts * len(expert_weights):
+                return []
+
+            tensors: list[tuple[str, Tensor]] = []
+            for w_name in expert_weights:
+                datas: list[Tensor] = []
+
+                for xid in range(n_experts):
+                    ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
+                    datas.append(expert_cache[ename])
+                    del expert_cache[ename]
+
+                data_torch = torch.stack(datas, dim=0)
+                merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
+                new_name = self.map_tensor_name(merged_name)
+                tensors.append((new_name, data_torch))
+
+            del self._experts_cache[bid]
+            return tensors
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Dots1ForCausalLM")
 class Dots1Model(Qwen2MoeModel):
     model_arch = gguf.MODEL_ARCH.DOTS1
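
A note on the expert merge in modify_tensors above: the HF checkpoint stores each expert's w1/w2/w3 as a separate 2-D tensor, while llama.cpp consumes one stacked 3-D tensor per projection per layer (the ffn_gate_exps/ffn_up_exps/ffn_down_exps names registered further down in this commit). A minimal self-contained sketch of that stacking, using plain torch and hypothetical shapes, independent of the conversion script:

    # Sketch of the stacking performed by modify_tensors above.
    # Shapes are hypothetical; the real ones come from the checkpoint.
    import torch

    n_experts, n_ff, n_embd = 4, 8, 6
    cache = {
        f"model.layers.0.block_sparse_moe.experts.{x}.w1.weight": torch.randn(n_ff, n_embd)
        for x in range(n_experts)
    }

    # experts 0..n-1 stacked along a new leading dim -> (n_experts, n_ff, n_embd)
    merged = torch.stack(
        [cache[f"model.layers.0.block_sparse_moe.experts.{x}.w1.weight"]
         for x in range(n_experts)],
        dim=0,
    )
    assert merged.shape == (n_experts, n_ff, n_embd)

The per-block cache exists because checkpoint shards yield expert tensors one at a time; nothing can be emitted until all n_experts * 3 weights for a block have arrived, which is what the len(expert_cache) < n_experts * len(expert_weights) early return checks.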

convert_hf_to_gguf_update.py

Lines changed: 6 additions & 0 deletions

@@ -141,6 +141,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "mellum",          "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
     {"name": "bailingmoe2",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
     {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
+    {"name": "minimax-m2",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
 ]

 # some models are known to be broken upstream, so we will skip them as exceptions

@@ -420,6 +421,8 @@ def get_vocab_base_pre(self, tokenizer) -> str:
 # with each model, encode all tests and write the results in ./models/ggml-vocab-{name}.gguf.out
 # for each test, write the resulting tokens on a separate line

+print(f"Have models: {models}\n\n")
+
 for model in models:
     name = model["name"]
     tokt = model["tokt"]

@@ -438,6 +441,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
     except OSError as e:
         logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
         continue  # Skip this model and continue with the next one in the loop
+    except TypeError as e:
+        logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
+        continue  # Skip this model and continue with the next one in the loop

     if not os.path.exists(f"models/ggml-vocab-{name}.gguf"):
         logger.info(f"Skip vocab files for model {name}, no GGUF file found")
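
For context, the chkhsh values matched in get_vocab_base_pre are fingerprints of pre-tokenizer behavior: the update script encodes a fixed probe text with each listed tokenizer and hashes the resulting token IDs, so two models collide only if their pre-tokenizers split text identically. A rough sketch of the idea; the probe string is a placeholder, not the script's real chktxt:

    # Sketch: fingerprint a tokenizer by hashing its token IDs for a probe text.
    # "chktxt" is a stand-in; the real probe string lives in convert_hf_to_gguf.py.
    from hashlib import sha256

    def pre_tokenizer_fingerprint(tokenizer, chktxt: str) -> str:
        ids = tokenizer.encode(chktxt)  # pre-tokenizer differences change these IDs
        return sha256(str(ids).encode()).hexdigest()

This is why adding the minimax-m2 entry to the models list is enough: rerunning the script fetches the tokenizer, computes the hash, and emits the corresponding if chkhsh == ... branch seen in convert_hf_to_gguf.py above.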

gguf-py/gguf/constants.py

Lines changed: 21 additions & 1 deletion

@@ -420,7 +420,7 @@ class MODEL_ARCH(IntEnum):
     SEED_OSS = auto()
     GROVEMOE = auto()
     APERTUS = auto()
-
+    MINIMAXM2 = auto()

 class VISION_PROJECTOR_TYPE(IntEnum):
     MLP = auto()

@@ -766,6 +766,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.SEED_OSS: "seed_oss",
     MODEL_ARCH.GROVEMOE: "grovemoe",
     MODEL_ARCH.APERTUS: "apertus",
+    MODEL_ARCH.MINIMAXM2: "minimax-m2",
 }

 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {

@@ -2837,6 +2838,25 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_CHEXP,
         MODEL_TENSOR.FFN_UP_CHEXP,
     ],
+    MODEL_ARCH.MINIMAXM2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     # TODO
 }
gguf-py/gguf/tensor_mapping.py

Lines changed: 1 addition & 0 deletions

@@ -377,6 +377,7 @@ class TensorNameMap:
         "model.layers.{bid}.mlp.moe_statics.e_score_correction",  # ernie4.5-moe
         "model.layers.{bid}.mlp.gate.expert_bias",                 # bailingmoe2
         "model.layers.{bid}.feed_forward.expert_bias",             # lfm2moe
+        "model.layers.{bid}.block_sparse_moe.e_score_correction",  # minimax-m2
     ),

     # Feed-forward up
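
The {bid} placeholder in these entries is the block (layer) index, substituted before lookup; this added row is what lets map_tensor_name resolve the merged MiniMax-M2 names built in modify_tensors. A toy version of the lookup idea (illustrative only, not gguf-py's actual TensorNameMap internals):

    # Toy version of the {bid}-templated name lookup (not gguf-py's real API).
    TEMPLATES: dict[str, tuple[str, ...]] = {
        "blk.{bid}.exp_probs_b": (
            "model.layers.{bid}.mlp.moe_statics.e_score_correction",   # ernie4.5-moe
            "model.layers.{bid}.block_sparse_moe.e_score_correction",  # minimax-m2
        ),
    }

    def map_name(hf_name: str, bid: int) -> str | None:
        for gguf_tmpl, hf_tmpls in TEMPLATES.items():
            if any(t.format(bid=bid) == hf_name for t in hf_tmpls):
                return gguf_tmpl.format(bid=bid)
        return None

    assert map_name("model.layers.3.block_sparse_moe.e_score_correction", 3) == "blk.3.exp_probs_b"

The real map also treats the .weight/.bias suffix separately, which is why modify_tensors first renames e_score_correction_bias to e_score_correction.bias.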

models/ggml-vocab-minimax-m2.gguf

7.85 MB (binary file not shown)
models/ggml-vocab-minimax-m2.gguf.inp

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
+ied 4 ½ months
+__ggml_vocab_test__
+Äpfel
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+ 
+__ggml_vocab_test__
+  
+__ggml_vocab_test__
+   
+__ggml_vocab_test__
+	
+__ggml_vocab_test__
+
+
+__ggml_vocab_test__
+
+
+
+__ggml_vocab_test__
+
+
+
+
+__ggml_vocab_test__
+	
+
+__ggml_vocab_test__
+Hello world
+__ggml_vocab_test__
+ Hello world
+__ggml_vocab_test__
+Hello World
+__ggml_vocab_test__
+ Hello World
+__ggml_vocab_test__
+ Hello World!
+__ggml_vocab_test__
+Hello, world!
+__ggml_vocab_test__
+ Hello, world!
+__ggml_vocab_test__
+ this is 🦙.cpp
+__ggml_vocab_test__
+w048 7tuijk dsdfhu
+__ggml_vocab_test__
+нещо на Български
+__ggml_vocab_test__
+កាន់តែពិសេសអាចខលចេញ
+__ggml_vocab_test__
+🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
+__ggml_vocab_test__
+Hello
+__ggml_vocab_test__
+ Hello
+__ggml_vocab_test__
+  Hello
+__ggml_vocab_test__
+   Hello
+__ggml_vocab_test__
+    Hello
+__ggml_vocab_test__
+    Hello
+    Hello
+__ggml_vocab_test__
+ (
+__ggml_vocab_test__
+
+ =
+__ggml_vocab_test__
+' era
+__ggml_vocab_test__
+Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
+__ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
+3
+__ggml_vocab_test__
+33
+__ggml_vocab_test__
+333
+__ggml_vocab_test__
+3333
+__ggml_vocab_test__
+33333
+__ggml_vocab_test__
+333333
+__ggml_vocab_test__
+3333333
+__ggml_vocab_test__
+33333333
+__ggml_vocab_test__
+333333333
+__ggml_vocab_test__
+Cửa Việt
+__ggml_vocab_test__
+ discards
+__ggml_vocab_test__
+
+ 
+
+ 
+
+
+ 	 		 	
+  
+   
+    
+     
+🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
+__ggml_vocab_test__
models/ggml-vocab-minimax-m2.gguf.out

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+1233 32 52 32 23901 4632
+69967 30230 295
+
+32
+256
+326
+9
+10
+367
+4368
+10380
+19739 2035
+53398 2035
+19739 5476
+53398 5476
+53398 5476 33
+19739 44 2035 33
+53398 44 2035 33
+546 355 9753 166 153 46 52243
+119 48218 32 55 116 2157 60350 40081 6107 15931
+8827 40614 3642 11575 185034 8623
+76300 128 76300 182 76300 147 157246 139 76300 143 157246 130 76300 150 76300 183 76300 159 225 35097 76300 159 76300 162 76300 182 76300 133 76300 129 76300 155 76300 133 225 35097 76300 137
+150333 359 14291 41 19918 182 61587 79213 171 21243 359 79401 158243 176756 41 181343 359 10141 113958 389 760 1072 1813 11248 41
+19739
+53398
+32 53398
+256 53398
+326 53398
+326 53398 10 326 53398
+359
+10 409
+39 5784
+19739 44 330 53147 33 2329 457 390 184404 3479 32020 594 44450 2489 17246 35341 49 1419 5516
+34485 6255
+51
+2893
+18397
+18397 51
+18397 2893
+18397 18397
+18397 18397 51
+18397 18397 2893
+18397 18397 18397
+67 191937 97 31042 84408 116
+2300 2958
+137106 35066 24361 56254 151540 4315 10877 7671 41564 150333 359 14291 41 19918 182 61587 79213 171 21243 359 79401 158243 176756 41 181343 9753 166 153 186278 153 32 51 32 2893 32 18397 32 18397 51 32 18397 2893 32 18397 18397 32 18397 18397 51 32 18397 18397 2893 32 51 46 51 32 51 645 51 32 51 1662 51 29559 158 128 76300 182 76300 147 157246 139 76300 143 157246 130 76300 150 76300 183 76300 159 225 35097 76300 159 76300 162 76300 182 76300 133 21557 129 3479 32020 594 44450 2489 17246 35341 49 1419 5516 109618 1246 9435 6833 40614 3642 11575 185034 8623 8462 3443 64346 2765 111832 22815 34485 6255 61018 13074 8244 1040 722 116 1186 13396 986 44 722 2380 390 3123 63 722 77 516 3123 13098 1454 412 44 722 68 390 1079 1001 17251 63 1559 39 34121 258 99132 76
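
These two text files form the tokenizer regression pair: each __ggml_vocab_test__-delimited chunk of the .inp file corresponds to one line of space-separated token IDs in the .out file (line 12, 19739 2035, is the encoding of "Hello world"; the empty third line is the empty-string test). A sketch of how such a pair can be regenerated, in the spirit of the update script's comment above ("encode all tests and write the resulting tokens on a separate line"); the exact calls are illustrative:

    # Sketch: regenerate a .gguf.out file from a .gguf.inp corpus (illustrative).
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M2")

    with open("models/ggml-vocab-minimax-m2.gguf.inp", encoding="utf-8") as f:
        # drop the empty tail after the final marker
        tests = f.read().split("\n__ggml_vocab_test__\n")[:-1]

    with open("models/ggml-vocab-minimax-m2.gguf.out", "w", encoding="utf-8") as f:
        for text in tests:
            ids = tok.encode(text, add_special_tokens=False)
            f.write(" ".join(map(str, ids)) + "\n")

llama.cpp's tokenizer tests then encode the same inputs through the bundled ggml-vocab-minimax-m2.gguf and fail if any line of IDs differs.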

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions

@@ -103,6 +103,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_SEED_OSS,   "seed_oss"   },
     { LLM_ARCH_GROVEMOE,   "grovemoe"   },
     { LLM_ARCH_APERTUS,    "apertus"    },
+    { LLM_ARCH_MINIMAX_M2, "minimax-m2" },
     { LLM_ARCH_UNKNOWN,    "(unknown)"  },
 };

@@ -2312,6 +2313,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
         { LLM_TENSOR_FFN_UP_CHEXPS, "blk.%d.ffn_up_chexps" },
     },
 },
+    {
+        LLM_ARCH_MINIMAX_M2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
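
The one arch-specific entry here is LLM_TENSOR_FFN_EXP_PROBS_B ("blk.%d.exp_probs_b"), carrying the expert-score correction bias converted above. With sigmoid scoring, the established pattern for such a bias, as in DeepSeek-V3-style routers, is that it shifts which experts win top-k selection while the mixing weights use the unbiased scores; whether llama.cpp's MiniMax-M2 graph does exactly this is not shown in this commit, so the following is only a sketch of the routing idea:

    # Sketch: sigmoid gating with a score-correction bias (DeepSeek-V3-style;
    # illustrative, not llama.cpp's implementation).
    import torch

    def route(logits: torch.Tensor, bias: torch.Tensor, top_k: int):
        scores = torch.sigmoid(logits)             # [n_tokens, n_experts]
        _, idx = torch.topk(scores + bias, top_k)  # bias affects selection only...
        weights = torch.gather(scores, -1, idx)    # ...not the mixing weights
        return idx, weights / weights.sum(-1, keepdim=True)

    idx, w = route(torch.randn(2, 8), torch.zeros(8), top_k=2)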

src/llama-arch.h

Lines changed: 1 addition & 0 deletions

@@ -107,6 +107,7 @@ enum llm_arch {
     LLM_ARCH_SEED_OSS,
     LLM_ARCH_GROVEMOE,
     LLM_ARCH_APERTUS,
+    LLM_ARCH_MINIMAX_M2,
     LLM_ARCH_UNKNOWN,
 };
