
Commit 4a085c7

Add lille 130m (ml-explore#429)
* in. com.
* inference works
* rebase
* cpyrgt
* upd. ackn
* clean up residuals
* format
* rebase + nits

---------

Co-authored-by: Awni Hannun <[email protected]>
1 parent 87961a7 commit 4a085c7

4 files changed: 187 additions & 2 deletions


ACKNOWLEDGMENTS.md

Lines changed: 14 additions & 2 deletions

```diff
@@ -8,5 +8,17 @@ with a short description of your contribution(s) below. For example:
 MLX LM was developed with contributions from the following individuals:
 
 - Shunta Saito: Added support for PLaMo models.
-- Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's `Mamba v1`, Z.ai & THUKEG's `GLM4`, Rednote's `dots.llm1`, Baidu's `Ernie4.5 MoE`, inclusionAI's `Bailing MoE` (e.g. the Ling family), Kuaishou Technology's Klear team's `Klear`, IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's `Apertus`, and Allenai's `OLMoE`; added support for the following training algorithms: `Full Weight Fine-Tuning` and the `Muon` optimizer; added support for other features: multiple optimizers to choose from for training, and reporting training metrics to WandB (Weights & Biases).
-- Prince Canuma: Helped add support for the following model architectures: HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen (2, 3 and MoE)`, Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama (3 and 4)`, Google DeepMind's `Gemma 3`, and InternLM's `InternLM 2.5`.
+- Gökdeniz Gülmez: Added support for the following architectures: OpenBMB's
+  `MiniCPM` and `MiniCPM3`, Kyutai's `Helium`, State-Space's `Mamba v1`, Z.ai &
+  THUKEG's `GLM4`, Rednote's `dots.llm1`, Baidu's `Ernie4.5 MoE`, inclusionAI's
+  `Bailing MoE` (e.g. the Ling family), Kuaishou Technology's Klear team's `Klear`,
+  IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's
+  `Apertus`, Nikity's `Lille130m`, and Allenai's `OLMoE`; added support for the
+  following training algorithms: `Full Weight Fine-Tuning` and the `Muon`
+  optimizer; added support for the following other features: multiple
+  optimizers to choose from for training, and reporting training metrics to
+  WandB (Weights & Biases).
+- Prince Canuma: Helped add support for the following model architectures:
+  HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen
+  (2, 3 and MoE)`, Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama
+  (3 and 4)`, Google DeepMind's `Gemma 3`, and InternLM's `InternLM 2.5`.
```

mlx_lm/models/lille-130m.py

Lines changed: 154 additions & 0 deletions (new file)

```python
# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    block_size: int
    layer_norm_eps: float
    n_embd: int
    n_head: int
    n_kv_heads: int
    n_layer: int
    rope_theta: float
    vocab_size: int
    tie_word_embeddings: bool = True


class Lille130mAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_head = args.n_head
        self.n_kv_heads = args.n_kv_heads
        self.head_dim = args.n_embd // args.n_head
        self.scale = self.head_dim**-0.5

        # Fused QKV projection: n_head query heads plus 2 * n_kv_heads
        # key/value heads (grouped-query attention when n_kv_heads < n_head).
        self.qkv_proj = nn.Linear(
            args.n_embd, (args.n_head + 2 * args.n_kv_heads) * self.head_dim, bias=False
        )
        self.out_proj = nn.Linear(args.n_head * self.head_dim, args.n_embd, bias=False)

        self.norm = nn.RMSNorm(args.n_embd, eps=args.layer_norm_eps)

        self.rope = nn.RoPE(args.n_embd // args.n_head, True, args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.qkv_proj(self.norm(x))

        # Split the fused projection back into queries, keys, and values.
        q_size = self.n_head * self.head_dim
        kv_size = self.n_kv_heads * self.head_dim

        queries, keys, values = mx.split(qkv, [q_size, q_size + kv_size], axis=-1)

        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out_proj(output)


class Lille130mMLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # SwiGLU hidden size: 8/3 * n_embd, rounded to a multiple of 256.
        hidden_dim = 256 * round(int(8 * args.n_embd / 3) / 256)

        self.norm = nn.RMSNorm(args.n_embd, eps=args.layer_norm_eps)
        self.gate_proj = nn.Linear(args.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(args.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, args.n_embd, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        h = self.norm(x)
        return self.down_proj(nn.silu(self.gate_proj(h)) * self.up_proj(h))


class Lille130Block(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention = Lille130mAttention(args)
        self.feed_forward = Lille130mMLP(args)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        # Pre-norm residual block; the norms live inside the sub-modules.
        h = x + self.attention(x, mask, cache)
        out = h + self.feed_forward(h)
        return out


class Lille130(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.n_embd)
        self.layers = [Lille130Block(args=args) for _ in range(args.n_layer)]
        self.norm = nn.RMSNorm(args.n_embd, eps=args.layer_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
    ) -> mx.array:
        h = self.tok_embeddings(inputs)

        if cache is None:
            cache = [None] * len(self.layers)

        mask = create_attention_mask(h, cache[0])

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, cache=c)

        # Tied embeddings: reuse the input embedding matrix as the LM head.
        return self.tok_embeddings.as_linear(self.norm(h))


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = Lille130(args)

    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
    ) -> mx.array:
        return self.transformer(inputs, cache=cache)

    @property
    def layers(self):
        return self.transformer.layers

    def sanitize(self, weights):
        # Drop precomputed rotary-embedding buffers from upstream checkpoints.
        return {k: v for k, v in weights.items() if "rotary_emb" not in k}
```
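To sanity-check the new module in isolation, here is a minimal sketch (not part of the commit) that instantiates the model at the test-sized hyperparameters below and verifies the output shape. The string-based import mirrors how mlx_lm resolves a `model_type` to its module file; a plain `import` statement cannot spell the hyphenated filename.

```python
import importlib

import mlx.core as mx

# Import by string: the file is lille-130m.py, so a normal import statement
# would be a syntax error. This mirrors mlx_lm's model_type -> module lookup.
lille = importlib.import_module("mlx_lm.models.lille-130m")

args = lille.ModelArgs(
    model_type="lille-130m",
    block_size=128,
    layer_norm_eps=1e-5,
    n_embd=128,
    n_head=4,
    n_kv_heads=4,
    n_layer=4,
    rope_theta=1000.0,
    vocab_size=1000,
)
model = lille.Model(args)

logits = model(mx.array([[1, 2, 3, 4]]))
print(logits.shape)  # (1, 4, 1000): batch, sequence length, vocab size
```

End-to-end inference on a published checkpoint would instead go through `mlx_lm.load` and `mlx_lm.generate` once converted weights are available.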

mlx_lm/tuner/utils.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -129,6 +129,7 @@ def to_lora(layer):
         "seed_oss",
         "apertus",
         "Klear",
+        "lille-130m",
     }:
         keys = {"self_attn.q_proj", "self_attn.v_proj"}
         if model.model_type in ["mixtral", "phimoe"]:
@@ -140,6 +141,12 @@ def to_lora(layer):
             keys.add("mlp.gate")
         if model.model_type in ["longcat_flash"]:
             keys.add("mlp.router.classifier")
+        if model.model_type == "lille-130m":
+            keys.add("attention.qkv_proj")
+            keys.add("attention.out_proj")
+            keys.add("feed_forward.gate_proj")
+            keys.add("feed_forward.up_proj")
+            keys.add("feed_forward.down_proj")
     elif model.model_type == "gpt_bigcode":
         keys = {"attn.c_attn"}
     elif model.model_type == "gpt2":
```
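These keys name the linear submodules that receive LoRA adapters for this model; the fused `attention.qkv_proj` stands in for the per-projection `q_proj`/`v_proj` that the generic branch targets. As a hedged sketch of wiring this up by hand, reusing `model` from the instantiation example above and assuming the usual rank/scale/dropout `lora_parameters` dict accepted by `linear_to_lora_layers` (the tuner normally does this for you):

```python
from mlx_lm.tuner.utils import linear_to_lora_layers

# Attach LoRA adapters to the last two transformer blocks; the key set added
# in this hunk decides which Linear layers inside each block get wrapped.
linear_to_lora_layers(model, 2, {"rank": 8, "scale": 20.0, "dropout": 0.0})
```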

tests/test_models.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -1671,6 +1671,18 @@ def test_all_models(self):
                 "max_position_embeddings": 1000,
                 "norm_topk_prob": True,
             },
+            {
+                "model_type": "lille-130m",
+                "block_size": 128,
+                "num_hidden_layers": 4,
+                "n_layer": 4,
+                "n_head": 4,
+                "n_kv_heads": 4,
+                "n_embd": 128,
+                "vocab_size": 1000,
+                "rope_theta": 1000,
+                "layer_norm_eps": 1e-5,
+            },
         ]
         for config in test_configs:
             model_type = config["model_type"]
```
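`test_all_models` runs each of these configs through the shared shape and KV-cache checks. A self-contained sketch of the cached decode path against this config, assuming `make_prompt_cache` from `mlx_lm.models.cache` (which builds one cache per entry in `model.layers`):

```python
import importlib

import mlx.core as mx
from mlx_lm.models.cache import make_prompt_cache

lille = importlib.import_module("mlx_lm.models.lille-130m")
model = lille.Model(
    lille.ModelArgs(
        model_type="lille-130m",
        block_size=128,
        layer_norm_eps=1e-5,
        n_embd=128,
        n_head=4,
        n_kv_heads=4,
        n_layer=4,
        rope_theta=1000.0,
        vocab_size=1000,
    )
)

cache = make_prompt_cache(model)            # one KV cache per block
model(mx.array([[1, 2, 3]]), cache=cache)   # prefill with the prompt
step = model(mx.array([[4]]), cache=cache)  # single-token decode
print(step.shape)  # (1, 1, 1000)
```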
