# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Any, List, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int
    max_position_embeddings: int
    num_experts_per_tok: int
    num_local_experts: int
    shared_intermediate_size: int
    num_hidden_layers: int
    rms_norm_eps: float
    rope_theta: float
    rotary_dim: int
    vocab_size: int
    tie_word_embeddings: bool = False
    scoring_func: str = "sigmoid"
    head_dim: Optional[int] = None
    use_qk_norm: bool = True

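# ModelArgs mirrors the fields of the checkpoint's config.json; in mlx-lm these
# dataclasses are typically populated via BaseModelArgs.from_dict(config), which
# drops unknown keys. When head_dim is omitted it falls back to
# hidden_size // num_attention_heads (see MiniMaxAttention below).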

class MiniMaxAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.hidden_dim = hidden_size = args.hidden_size

        self.num_attention_heads = args.num_attention_heads
        self.num_key_value_heads = args.num_key_value_heads
        self.head_dim = head_dim = (
            args.head_dim or hidden_size // args.num_attention_heads
        )
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(
            args.hidden_size, self.num_attention_heads * head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            args.hidden_size, self.num_key_value_heads * head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            args.hidden_size, self.num_key_value_heads * head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_attention_heads * head_dim, args.hidden_size, bias=False
        )

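        # Note: QK-norm, when enabled, is applied to the full projected tensor
        # (all heads concatenated) before it is split into heads, so the learned
        # RMSNorm scales below have num_heads * head_dim entries.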
        self.use_qk_norm = args.use_qk_norm if hasattr(args, "use_qk_norm") else False
        if self.use_qk_norm:
            self.q_norm = nn.RMSNorm(
                head_dim * self.num_attention_heads, eps=args.rms_norm_eps
            )
            self.k_norm = nn.RMSNorm(
                head_dim * self.num_key_value_heads, eps=args.rms_norm_eps
            )

        self.rope = nn.RoPE(args.rotary_dim, traditional=False, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        if self.use_qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
            0, 2, 1, 3
        )
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

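        # With a KV cache, rotary embeddings are applied at an offset so positions
        # continue from where the previous decoding step left off.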
        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.o_proj(output)


class MiniMaxSparseMoeBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_experts_per_tok = args.num_experts_per_tok

        self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False)
        self.switch_mlp = SwitchGLU(
            args.hidden_size, args.intermediate_size, args.num_local_experts
        )
        self.e_score_correction_bias = mx.zeros((args.num_local_experts,))

    def __call__(self, x: mx.array) -> mx.array:
        gates = self.gate(x.astype(mx.float32))

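        # Routing: sigmoid scores shifted by e_score_correction_bias are used only
        # to pick the top-k experts; the unbiased scores of the chosen experts are
        # then renormalized to sum to one and used as the mixture weights.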
        scores = mx.sigmoid(gates)
        orig_scores = scores
        scores = scores + self.e_score_correction_bias

        k = self.num_experts_per_tok
        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
        scores = mx.take_along_axis(orig_scores, inds, axis=-1)

        scores = scores / (mx.sum(scores, axis=-1, keepdims=True) + 1e-20)
        scores = scores.astype(x.dtype)

        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)
        return y


class MiniMaxDecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.self_attn = MiniMaxAttention(args)

        self.block_sparse_moe = MiniMaxSparseMoeBlock(args)

        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = x + self.self_attn(self.input_layernorm(x), mask, cache)
        r = r + self.block_sparse_moe(self.post_attention_layernorm(r))
        return r


class MiniMaxModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)

        self.layers = [
            MiniMaxDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
        ]

        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        h = self.embed_tokens(inputs)

        if cache is None:
            cache = [None] * len(self.layers)

        if mask is None:
            mask = create_attention_mask(h, cache[0])

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = MiniMaxModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ):
        out = self.model(inputs=inputs, mask=mask, cache=cache)
        if self.args.tie_word_embeddings:
            out = self.model.embed_tokens.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        """Dequantize FP8 weights and restructure MoE experts."""

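        # FP8 checkpoints store one inverse scale per 128x128 block of each weight
        # matrix. dequant pads the weight up to a multiple of the block size,
        # scales every block, then crops back to the original shape.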
        def dequant(weight, scale_inv):
            dtype = weight.dtype
            bs = 128  # block size
            m, n = weight.shape
            pad_bottom = (-m) % bs
            pad_side = (-n) % bs
            weight = mx.pad(weight, ((0, pad_bottom), (0, pad_side)))
            weight = weight.reshape(
                ((m + pad_bottom) // bs, bs, (n + pad_side) // bs, bs)
            )
            weight = (weight * scale_inv[:, None, :, None]).reshape(
                m + pad_bottom, n + pad_side
            )
            return weight[:m, :n].astype(dtype)

        # Dequantize any FP8 block-quantized weights
        new_weights = {}
        for k, v in weights.items():
            if "weight_scale_inv" in k:
                scale_inv = v
                wk = k.replace("_scale_inv", "")
                weight = weights[wk]
                weight = dequant(weight, scale_inv)
                new_weights[wk] = weight
            elif k not in new_weights:
                new_weights[k] = v
        weights = new_weights

        # Restack per-expert MoE weights into the layout SwitchGLU expects
        if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
            return weights

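        # The checkpoint stores each expert's projections separately as
        # experts.{e}.w1/w2/w3; stack them along a new leading expert axis so
        # SwitchGLU can index them with the router's expert indices.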
        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}"
            mapping = {"w1": "gate_proj", "w2": "down_proj", "w3": "up_proj"}
            for orig_name, new_name in mapping.items():
                if f"{prefix}.block_sparse_moe.experts.0.{orig_name}.weight" in weights:
                    to_join = [
                        weights.pop(
                            f"{prefix}.block_sparse_moe.experts.{e}.{orig_name}.weight"
                        )
                        for e in range(self.args.num_local_experts)
                    ]
                    weights[
                        f"{prefix}.block_sparse_moe.switch_mlp.{new_name}.weight"
                    ] = mx.stack(to_join)

        return weights

    @property
    def layers(self):
        return self.model.layers

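    # Exclude the routing correction bias from dtype casting so it stays in full
    # precision, and quantize the MoE router gate more conservatively
    # (8-bit, group size 64) than the rest of the network.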
    @property
    def cast_predicate(self):
        def predicate(k):
            return "e_score_correction_bias" not in k

        return predicate

    @property
    def quant_predicate(self):
        def predicate(path, _):
            if path.endswith("block_sparse_moe.gate"):
                return {"group_size": 64, "bits": 8}
            return True

        return predicate
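
# A minimal, illustrative sketch of loading and running a converted checkpoint with
# mlx-lm (the path below is hypothetical, not a real model id):
#
#   from mlx_lm import load, generate
#
#   model, tokenizer = load("path/to/minimax-mlx")
#   print(generate(model, tokenizer, prompt="Hello", max_tokens=32))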