
Commit 87961a7

adding Kwai-Klear/Klear-46B-A2.5B-Instruct (ml-explore#437)
* in. com.
* clean up
* sanitize
* fix
* nits
* making it trainable
* format
* upd. ackn
* rebase + nits
* rebase + nits

---------

Co-authored-by: Awni Hannun <[email protected]>
1 parent d79f3cd commit 87961a7

File tree

- ACKNOWLEDGMENTS.md
- mlx_lm/models/Klear.py
- mlx_lm/tuner/utils.py
- tests/test_models.py

4 files changed (+285, -2 lines)

ACKNOWLEDGMENTS.md

Lines changed: 1 addition & 1 deletion
@@ -8,5 +8,5 @@ with a short description of your contribution(s) below. For example:
 MLX LM was developed with contributions from the following individuals:
 
 - Shunta Saito: Added support for PLaMo models.
-- Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's`Mamba v1`, Z.ai & THUKEG's `GLM4`, Rednote `dots.llm1`, Baisu's `Ernie4.5 MoE`, inclusionAI's `Bailing MoE e.g. Ling-family`, IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's `Apertus`, and Allenai's `OLMoE`; Added support for the following training algorithms: `Full Weight Fine-Tuning`, and the `Muon` optimizer; Added support for the following other features: `Multiple Optimizers to choose for training`, and `reporting training metrics to WandB (Weights & Biases)`.
+- Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's`Mamba v1`, Z.ai & THUKEG's `GLM4`, Rednote `dots.llm1`, Baisu's `Ernie4.5 MoE`, inclusionAI's `Bailing MoE e.g. Ling-family`, Klear team - Kuaishou Technology's `Klear`, IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's `Apertus`, and Allenai's `OLMoE`; Added support for the following training algorithms: `Full Weight Fine-Tuning`, and the `Muon` optimizer; Added support for the following other features: `Multiple Optimizers to choose for training`, and `reporting training metrics to WandB (Weights & Biases)`.
 - Prince Canuma: Helped add support for the following model architectures: HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen (2, 3 and MoE)`, Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama (3 and 4)`, Google DeepMind's `Gemma 3`, and InterLM's `InternLM 2.5`.

mlx_lm/models/Klear.py

Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Any, List, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    attention_bias: bool
    mlp_only_layers: List[int]
    num_experts: int
    num_experts_per_tok: int
    decoder_sparse_step: int
    n_shared_experts: int
    moe_intermediate_size: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    rope_theta: float
    max_position_embeddings: int
    norm_topk_prob: bool

class KlearAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.num_key_value_heads = args.num_key_value_heads

        self.head_dim = args.hidden_size // args.num_attention_heads
        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(
            args.hidden_size,
            self.num_attention_heads * self.head_dim,
            bias=args.attention_bias,
        )
        self.k_proj = nn.Linear(
            args.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=args.attention_bias,
        )
        self.v_proj = nn.Linear(
            args.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=args.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_attention_heads * self.head_dim,
            args.hidden_size,
            bias=args.attention_bias,
        )

        self.q_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps)

        self.rope = nn.RoPE(
            self.head_dim,
            traditional=False,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = self.q_norm(
            queries.reshape(B, L, self.num_attention_heads, -1)
        ).transpose(0, 2, 1, 3)
        keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose(
            0, 2, 1, 3
        )
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
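For orientation, a shape-check sketch of the attention block (toy sizes, not the released 46B-A2.5B configuration; it assumes only that mlx and this new module are importable). It exercises the GQA layout (four query heads sharing two key/value heads in this toy setup) and the per-head RMSNorm applied to queries and keys before RoPE:

# Toy shape check for KlearAttention; all sizes here are made up for illustration.
import mlx.core as mx
from mlx_lm.models.Klear import KlearAttention, ModelArgs

args = ModelArgs(
    model_type="Klear", hidden_size=128, num_hidden_layers=1,
    intermediate_size=256, num_attention_heads=4, attention_bias=False,
    mlp_only_layers=[], num_experts=4, num_experts_per_tok=2,
    decoder_sparse_step=1, n_shared_experts=1, moe_intermediate_size=64,
    rms_norm_eps=1e-5, vocab_size=1000, num_key_value_heads=2,
    rope_theta=10000.0, max_position_embeddings=2048, norm_topk_prob=True,
)
attn = KlearAttention(args)
x = mx.random.normal((1, 8, args.hidden_size))   # (batch, sequence, hidden)
print(attn(x, mask=None, cache=None).shape)      # (1, 8, 128): o_proj maps heads back to hidden_size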
class KlearMLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class KlearSparseMoeBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.norm_topk_prob = args.norm_topk_prob
        self.num_experts = args.num_experts
        self.top_k = args.num_experts_per_tok

        self.gate = nn.Linear(args.hidden_size, args.num_experts, bias=False)
        self.experts = SwitchGLU(
            args.hidden_size, args.moe_intermediate_size, args.num_experts
        )
        self.shared_experts = KlearMLP(
            args.hidden_size,
            hidden_dim=args.moe_intermediate_size * args.n_shared_experts,
        )
        self.coefficient = nn.Linear(args.hidden_size, 2)
        self.expert_bias = mx.zeros((self.num_experts,), dtype=mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        routing_weights = mx.sigmoid(self.gate(x).astype(mx.float32))
        biased_weights = routing_weights + self.expert_bias.reshape((1, 1, -1))
        k = self.top_k
        inds = mx.argpartition(-biased_weights, kth=k - 1, axis=-1)[..., :k]
        scores = mx.take_along_axis(routing_weights, inds, axis=-1)
        if self.norm_topk_prob:
            scores = scores / mx.sum(scores, axis=-1, keepdims=True)
        scores = scores.astype(x.dtype)
        expert_out = self.experts(x, inds)
        y_experts = (expert_out * scores[..., None]).sum(axis=-2)
        coef = mx.softmax(self.coefficient(x), axis=-1, precise=True)
        shared = self.shared_experts(x)
        y = y_experts * coef[..., :1] + shared * coef[..., 1:]
        return y
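Two details of the routing above are easy to miss: expert_bias only affects which experts are selected (the mixing weights come from the unbiased sigmoid scores), and a learned two-way softmax blends the routed experts with the always-on shared expert. A standalone sketch of the selection step, with made-up numbers:

import mlx.core as mx

logits = mx.array([[0.2, -1.0, 1.5, 0.3]])     # gate output for one token, 4 experts
expert_bias = mx.array([0.0, 2.0, 0.0, 0.0])   # load-balancing bias, used for selection only

weights = mx.sigmoid(logits)                   # routing weights in (0, 1)
biased = weights + expert_bias                 # the bias shifts which experts win the top-k...
inds = mx.argpartition(-biased, kth=1, axis=-1)[..., :2]
scores = mx.take_along_axis(weights, inds, axis=-1)       # ...but not how they are weighted
scores = scores / mx.sum(scores, axis=-1, keepdims=True)  # norm_topk_prob=True
print(inds, scores)  # expert 1 is picked because of its bias, yet weighted by sigmoid(-1.0)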
class KlearDecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.self_attn = KlearAttention(args)

        if (layer_idx not in args.mlp_only_layers) and (
            args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0
        ):
            self.mlp = KlearSparseMoeBlock(args)
        else:
            self.mlp = KlearMLP(args.hidden_size, args.intermediate_size)

        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out
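Whether a layer gets the sparse block is decided per index by the condition above. A quick check with hypothetical settings (decoder_sparse_step=2, layer 0 forced dense via mlp_only_layers, 8 layers):

mlp_only_layers, decoder_sparse_step = [0], 2
num_experts, num_hidden_layers = 4, 8
sparse_layers = [
    i
    for i in range(num_hidden_layers)
    if i not in mlp_only_layers
    and num_experts > 0
    and (i + 1) % decoder_sparse_step == 0
]
print(sparse_layers)  # [1, 3, 5, 7] -> these layers use KlearSparseMoeBlock, the rest KlearMLP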
class KlearModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            KlearDecoderLayer(args=args, layer_idx=i)
            for i in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
    ) -> mx.array:
        h = self.embed_tokens(inputs)

        if cache is None:
            cache = [None] * len(self.layers)

        mask = create_attention_mask(h, cache[0])

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = KlearModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
    ):
        out = self.model(inputs, cache)
        return self.lm_head(out)

    def sanitize(self, weights):
        if "model.layers.0.mlp.experts.0.gate_proj.weight" not in weights:
            return weights

        for l in range(self.args.num_hidden_layers):
            prefix = f"model.layers.{l}.mlp.experts"
            for name in ["gate_proj", "up_proj", "down_proj"]:
                stacked = [
                    weights.pop(f"{prefix}.{e}.{name}.weight")
                    for e in range(self.args.num_experts)
                ]
                weights[f"{prefix}.{name}.weight"] = mx.stack(stacked)

        return weights

    @property
    def layers(self):
        return self.model.layers

    @property
    def quant_predicate(self):
        def predicate(path, _):
            if path.endswith("mlp.gate"):
                return {"group_size": 64, "bits": 8}
            return True

        return predicate

    @property
    def cast_predicate(self):
        def predicate(k):
            return "expert_bias" not in k

        return predicate
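The Model class also wires up the loader hooks: sanitize stacks the per-expert gate/up/down matrices into the layout SwitchGLU expects, quant_predicate keeps the MoE router gate at 8 bits with group size 64, and cast_predicate leaves expert_bias in float32. With that in place, the standard mlx_lm entry points should pick up the new model_type; a minimal generation sketch (assuming the Hugging Face repo id from the commit title, the usual load/generate API, and either enough memory for the 46B checkpoint or a converted/quantized local copy):

# Hedged usage sketch, not part of the commit.
from mlx_lm import load, generate

model, tokenizer = load("Kwai-Klear/Klear-46B-A2.5B-Instruct")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about MoE routing."}],
    add_generation_prompt=True,
    tokenize=False,
)
print(generate(model, tokenizer, prompt=prompt, max_tokens=64))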

mlx_lm/tuner/utils.py

Lines changed: 2 additions & 1 deletion
@@ -128,14 +128,15 @@ def to_lora(layer):
         "longcat_flash",
         "seed_oss",
         "apertus",
+        "Klear",
     }:
         keys = {"self_attn.q_proj", "self_attn.v_proj"}
         if model.model_type in ["mixtral", "phimoe"]:
             keys.add("block_sparse_moe.gate")
         if model.model_type == "qwen2_moe":
             keys.add("mlp.gate")
             keys.add("mlp.shared_expert_gate")
-        if model.model_type in ["olmoe", "qwen3_moe", "dots1"]:
+        if model.model_type in ["olmoe", "qwen3_moe", "dots1", "Klear"]:
             keys.add("mlp.gate")
         if model.model_type in ["longcat_flash"]:
             keys.add("mlp.router.classifier")

tests/test_models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,6 +1651,26 @@ def test_all_models(self):
16511651
"num_key_value_heads": 2,
16521652
"head_dim": 64,
16531653
},
1654+
{
1655+
"model_type": "Klear",
1656+
"hidden_size": 128,
1657+
"num_hidden_layers": 4,
1658+
"intermediate_size": 128,
1659+
"num_attention_heads": 4,
1660+
"attention_bias": False,
1661+
"mlp_only_layers": [0],
1662+
"num_experts": 4,
1663+
"num_experts_per_tok": 2,
1664+
"decoder_sparse_step": 2,
1665+
"n_shared_experts": 1,
1666+
"moe_intermediate_size": 128,
1667+
"rms_norm_eps": 1e-5,
1668+
"vocab_size": 1000,
1669+
"num_key_value_heads": 4,
1670+
"rope_theta": 1000.0,
1671+
"max_position_embeddings": 1000,
1672+
"norm_topk_prob": True,
1673+
},
16541674
]
16551675
for config in test_configs:
16561676
model_type = config["model_type"]
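The new config exercises both paths: layer 0 is forced dense through mlp_only_layers, while layers 1 and 3 hit the MoE block with decoder_sparse_step=2. A standalone sketch of roughly what test_all_models does with it (assumes mlx is installed; the real assertions live in tests/test_models.py):

import mlx.core as mx
from mlx_lm.models.Klear import Model, ModelArgs

args = ModelArgs(
    model_type="Klear", hidden_size=128, num_hidden_layers=4,
    intermediate_size=128, num_attention_heads=4, attention_bias=False,
    mlp_only_layers=[0], num_experts=4, num_experts_per_tok=2,
    decoder_sparse_step=2, n_shared_experts=1, moe_intermediate_size=128,
    rms_norm_eps=1e-5, vocab_size=1000, num_key_value_heads=4,
    rope_theta=1000.0, max_position_embeddings=1000, norm_topk_prob=True,
)
model = Model(args)
logits = model(mx.array([[0, 1, 2, 3]]))
print(logits.shape)  # (1, 4, 1000): one row of vocabulary logits per input token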
