
Commit f2b0262

Blaizzy and awni authored
Add Minimax-M2 (#568)
* add minimax m2
* fix dequant and decoder
* remove unused
* remove unused
* normalize scores
* refactor
* fix minimax
* fix

Co-authored-by: awni <[email protected]>
1 parent 367d6d7 commit f2b0262
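
The commit message mentions normalizing the routing scores. As a standalone illustration (not part of the diff), the sketch below mirrors how the new MiniMaxSparseMoeBlock in mlx_lm/models/minimax.py routes tokens: sigmoid scores plus a correction bias are used only to pick the top-k experts, and the original sigmoid scores of the picked experts are then renormalized to sum to one. All sizes here are made up for the example.

# Toy routing sketch; mirrors the MoE block in the diff below, not a drop-in piece of it.
import mlx.core as mx

gates = mx.random.normal((2, 8))  # router logits for 2 tokens over 8 experts
bias = mx.zeros((8,))             # plays the role of e_score_correction_bias
k = 2                             # num_experts_per_tok

scores = mx.sigmoid(gates)        # per-expert scores in (0, 1)
biased = scores + bias            # the bias only influences which experts are picked
inds = mx.argpartition(-biased, kth=k - 1, axis=-1)[..., :k]
picked = mx.take_along_axis(scores, inds, axis=-1)
weights = picked / (picked.sum(axis=-1, keepdims=True) + 1e-20)  # normalize scores
print(inds.shape, weights.sum(axis=-1))  # (2, 2); each row of weights sums to ~1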

File tree: 2 files changed, +303 −0 lines (mlx_lm/models/minimax.py, tests/test_models.py)

mlx_lm/models/minimax.py

Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Any, List, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int
    max_position_embeddings: int
    num_experts_per_tok: int
    num_local_experts: int
    shared_intermediate_size: int
    num_hidden_layers: int
    rms_norm_eps: float
    rope_theta: float
    rotary_dim: int
    vocab_size: int
    tie_word_embeddings: bool = False
    scoring_func: str = "sigmoid"
    head_dim: Optional[int] = None
    use_qk_norm: bool = True


class MiniMaxAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.hidden_dim = hidden_size = args.hidden_size

        self.num_attention_heads = args.num_attention_heads
        self.num_key_value_heads = args.num_key_value_heads
        self.head_dim = head_dim = (
            args.head_dim or hidden_size // args.num_attention_heads
        )
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(
            args.hidden_size, self.num_attention_heads * head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            args.hidden_size, self.num_key_value_heads * head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            args.hidden_size, self.num_key_value_heads * head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_attention_heads * head_dim, args.hidden_size, bias=False
        )

        self.use_qk_norm = args.use_qk_norm if hasattr(args, "use_qk_norm") else False
        if self.use_qk_norm:
            self.q_norm = nn.RMSNorm(
                head_dim * self.num_attention_heads, eps=args.rms_norm_eps
            )
            self.k_norm = nn.RMSNorm(
                head_dim * self.num_key_value_heads, eps=args.rms_norm_eps
            )

        self.rope = nn.RoPE(args.rotary_dim, traditional=False, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        if self.use_qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
            0, 2, 1, 3
        )
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.o_proj(output)


class MiniMaxSparseMoeBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_experts_per_tok = args.num_experts_per_tok

        self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False)
        self.switch_mlp = SwitchGLU(
            args.hidden_size, args.intermediate_size, args.num_local_experts
        )
        self.e_score_correction_bias = mx.zeros((args.num_local_experts,))

    def __call__(self, x: mx.array) -> mx.array:
        gates = self.gate(x.astype(mx.float32))

        scores = mx.sigmoid(gates)
        orig_scores = scores
        scores = scores + self.e_score_correction_bias

        k = self.num_experts_per_tok
        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
        scores = mx.take_along_axis(orig_scores, inds, axis=-1)

        scores = scores / (mx.sum(scores, axis=-1, keepdims=True) + 1e-20)
        scores = scores.astype(x.dtype)

        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)
        return y


class MiniMaxDecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.self_attn = MiniMaxAttention(args)

        self.block_sparse_moe = MiniMaxSparseMoeBlock(args)

        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = x + self.self_attn(self.input_layernorm(x), mask, cache)
        r = r + self.block_sparse_moe(self.post_attention_layernorm(r))
        return r


class MiniMaxModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)

        self.layers = [
            MiniMaxDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
        ]

        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        h = self.embed_tokens(inputs)

        if cache is None:
            cache = [None] * len(self.layers)

        mask = create_attention_mask(h, cache[0])

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = MiniMaxModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ):
        out = self.model(inputs=inputs, mask=mask, cache=cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        """Dequantize FP8 weights and restructure MoE experts."""

        def dequant(weight, scale_inv):
            dtype = weight.dtype
            bs = 128  # block size
            m, n = weight.shape
            pad_bottom = (-m) % bs
            pad_side = (-n) % bs
            weight = mx.pad(weight, ((0, pad_bottom), (0, pad_side)))
            weight = weight.reshape(
                ((m + pad_bottom) // bs, bs, (n + pad_side) // bs, bs)
            )
            weight = (weight * scale_inv[:, None, :, None]).reshape(
                m + pad_bottom, n + pad_side
            )
            return weight[:m, :n].astype(dtype)

        # Dequantize
        new_weights = {}
        for k, v in weights.items():
            if "weight_scale_inv" in k:
                scale_inv = v
                wk = k.replace("_scale_inv", "")
                weight = weights[wk]
                weight = dequant(weight, scale_inv)
                new_weights[wk] = weight
            elif k not in new_weights:
                new_weights[k] = v
        weights = new_weights

        # Step 2: Handle MoE expert weights restructuring
        if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
            return weights

        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            mapping = {"w1": "gate_proj", "w2": "down_proj", "w3": "up_proj"}
            for orig_name, new_name in mapping.items():
                if f"{prefix}.block_sparse_moe.experts.0.{orig_name}.weight" in weights:
                    to_join = [
                        weights.pop(
                            f"{prefix}.block_sparse_moe.experts.{e}.{orig_name}.weight"
                        )
                        for e in range(self.args.num_local_experts)
                    ]
                    weights[
                        f"{prefix}.block_sparse_moe.switch_mlp.{new_name}.weight"
                    ] = mx.stack(to_join)

        return weights

    @property
    def layers(self):
        return self.model.layers

    @property
    def cast_predicate(self):
        def predicate(k):
            return "e_score_correction_bias" not in k

        return predicate

    @property
    def quant_predicate(self):
        def predicate(path, _):
            if path.endswith("block_sparse_moe.gate"):
                return {"group_size": 64, "bits": 8}
            return True

        return predicate
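
For context, here is one way to exercise the new model once a MiniMax checkpoint converted for MLX is available, using the standard mlx_lm entry points; the repository path is a placeholder, not something this commit pins down.

# Hedged usage sketch; replace the path with an actual converted checkpoint.
from mlx_lm import load, generate

model, tokenizer = load("path/to/minimax-m2-mlx")
text = generate(model, tokenizer, prompt="Write a haiku about the sea.", max_tokens=64)
print(text)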

tests/test_models.py

Lines changed: 16 additions & 0 deletions
@@ -1939,6 +1939,22 @@ def test_all_models(self):
                "vocab_size": 32,
                "intermediate_size": 128,
            },
            {
                "model_type": "minimax",
                "hidden_size": 128,
                "intermediate_size": 128,
                "num_attention_heads": 8,
                "num_key_value_heads": 8,
                "max_position_embeddings": 1000,
                "num_experts_per_tok": 2,
                "num_local_experts": 8,
                "shared_intermediate_size": 128,
                "num_hidden_layers": 4,
                "rms_norm_eps": 1e-4,
                "rope_theta": 1000,
                "rotary_dim": 16,
                "vocab_size": 1000,
            },
        ]
        for config in test_configs:
            model_type = config["model_type"]
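
As a minimal sketch of what the new entry exercises: the config above maps one-to-one onto the ModelArgs fields in mlx_lm/models/minimax.py, so the model can be built and run on dummy token ids. The from_dict helper is assumed to come from BaseModelArgs, as the existing tests rely on it.

# Sketch only; ModelArgs.from_dict is assumed from BaseModelArgs.
import mlx.core as mx
from mlx_lm.models import minimax

config = {
    "model_type": "minimax",
    "hidden_size": 128,
    "intermediate_size": 128,
    "num_attention_heads": 8,
    "num_key_value_heads": 8,
    "max_position_embeddings": 1000,
    "num_experts_per_tok": 2,
    "num_local_experts": 8,
    "shared_intermediate_size": 128,
    "num_hidden_layers": 4,
    "rms_norm_eps": 1e-4,
    "rope_theta": 1000,
    "rotary_dim": 16,
    "vocab_size": 1000,
}
model = minimax.Model(minimax.ModelArgs.from_dict(config))
logits = model(mx.array([[1, 2, 3, 4]]))  # dummy token ids
print(logits.shape)  # expected: (1, 4, 1000), one row of logits per input token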

0 commit comments
