
Commit b12ef41

Port of nanochat (#415)
* port of nanochat
1 parent 8775112 commit b12ef41

2 files changed, +290 -0 lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 7 additions & 0 deletions
@@ -65,6 +65,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
             "olmo2": create(Olmo2Configuration.self, Olmo2Model.init),
             "bailing_moe": create(BailingMoeConfiguration.self, BailingMoeModel.init),
             "lfm2_moe": create(LFM2MoEConfiguration.self, LFM2MoEModel.init),
+            "nanochat": create(NanoChatConfiguration.self, NanoChatModel.init),
         ]
     }
 }
@@ -334,6 +335,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: ""
     )
 
+    static public let nanochat_d20_mlx = ModelConfiguration(
+        id: "dnakov/nanochat-d20-mlx",
+        defaultPrompt: ""
+    )
+
     private static func all() -> [ModelConfiguration] {
         [
             codeLlama13b4bit,
@@ -382,6 +388,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
             olmo_2_1124_7B_Instruct_4bit,
             ling_mini_2_2bit,
             lfm2_8b_a1b_3bit_mlx,
+            nanochat_d20_mlx,
         ]
     }
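For context, a minimal sketch of how the new registry entry might be consumed from client code. It assumes the repo's existing LLMModelFactory.shared.loadContainer(configuration:) entry point; the call shape is an assumption rather than part of this commit.

import MLXLLM
import MLXLMCommon

// Hypothetical usage: download and load the registered nanochat checkpoint.
// Assumes the factory's loadContainer(configuration:) API used by the examples.
func loadNanoChat() async throws -> ModelContainer {
    try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.nanochat_d20_mlx)
}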

NanoChat.swift (new file)

Lines changed: 283 additions & 0 deletions

@@ -0,0 +1,283 @@
//
// NanoChat.swift
// mlx-swift-examples
//
// Created by Sachin Desai 10/15/25.
//
// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/nanochat.py
//

import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN

// MARK: - Helpers

private func functionalRMSNorm(_ x: MLXArray, eps: Float) -> MLXArray {
    let meanSquares = mean(x.square(), axis: -1, keepDims: true)
    return x * (meanSquares + eps).rsqrt()
}

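// Tanh soft-capping: maps logits to cap * tanh(logits / cap), so large magnitudes
// saturate near ±cap while values well inside the cap pass through nearly unchanged.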
private func applySoftcap(_ logits: MLXArray, cap: Float) -> MLXArray {
    guard cap > 0 else { return logits }
    let scale = MLXArray(cap)
    return scale * tanh(logits / scale)
}

// MARK: - Attention

private final class NanoChatAttention: Module {
    let config: NanoChatConfiguration
    let numHeads: Int
    let numKVHeads: Int
    let headDim: Int
    let scale: Float

    @ModuleInfo(key: "c_q") var wq: Linear
    @ModuleInfo(key: "c_k") var wk: Linear
    @ModuleInfo(key: "c_v") var wv: Linear
    @ModuleInfo(key: "c_proj") var wo: Linear

    private let _ropeFreqs: MLXArray

    init(_ config: NanoChatConfiguration) {
        self.config = config
        self.numHeads = config.attentionHeads
        self.numKVHeads = config.kvHeads
        self.headDim = config.hiddenSize / config.attentionHeads
        precondition(headDim % 2 == 0, "Head dimension must be even for rotary embeddings.")

        self.scale = pow(Float(headDim), -0.5)

        _wq.wrappedValue = Linear(config.hiddenSize, numHeads * headDim, bias: false)
        _wk.wrappedValue = Linear(config.hiddenSize, numKVHeads * headDim, bias: false)
        _wv.wrappedValue = Linear(config.hiddenSize, numKVHeads * headDim, bias: false)
        _wo.wrappedValue = Linear(numHeads * headDim, config.hiddenSize, bias: false)

        let halfDim = headDim / 2
        let freqIndices = MLXArray(Array(0 ..< halfDim)).asType(.float32)
        let freqScale = Float(log(Double(config.ropeTheta)) / Double(halfDim))
        self._ropeFreqs = -MLX.exp(freqIndices * freqScale)
    }

    func callAsFunction(
        _ x: MLXArray,
        mask: MLXFast.ScaledDotProductAttentionMaskMode,
        cache: KVCache?
    ) -> MLXArray {
        let (batchSize, sequenceLength) = (x.dim(0), x.dim(1))

        var queries = wq(x)
        var keys = wk(x)
        var values = wv(x)

        queries = queries.reshaped(batchSize, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3)
        keys = keys.reshaped(batchSize, sequenceLength, numKVHeads, -1).transposed(0, 2, 1, 3)
        values = values.reshaped(batchSize, sequenceLength, numKVHeads, -1).transposed(0, 2, 1, 3)

        let offset = cache?.offset ?? 0
        let freqs = _ropeFreqs
        queries = MLXFast.RoPE(
            queries,
            dimensions: headDim,
            traditional: false,
            base: nil,
            scale: 1.0,
            offset: offset,
            freqs: freqs
        )
        keys = MLXFast.RoPE(
            keys,
            dimensions: headDim,
            traditional: false,
            base: nil,
            scale: 1.0,
            offset: offset,
            freqs: freqs
        )

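        // QK-norm: RMS-normalize queries and keys before computing attention.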
        queries = functionalRMSNorm(queries, eps: config.rmsNormEps)
        keys = functionalRMSNorm(keys, eps: config.rmsNormEps)

        let output = attentionWithCacheUpdate(
            queries: queries,
            keys: keys,
            values: values,
            cache: cache,
            scale: scale,
            mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(batchSize, sequenceLength, -1)

        return wo(output)
    }
}

// MARK: - MLP

private final class NanoChatMLP: Module, UnaryLayer {
    let config: NanoChatConfiguration

    @ModuleInfo(key: "c_fc") var fc: Linear
    @ModuleInfo(key: "c_proj") var proj: Linear

    init(_ config: NanoChatConfiguration) {
        self.config = config
        _fc.wrappedValue = Linear(config.hiddenSize, config.intermediateSize, bias: false)
        _proj.wrappedValue = Linear(config.intermediateSize, config.hiddenSize, bias: false)
    }

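    // Squared-ReLU feed-forward: proj(relu(fc(x))^2).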
    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let activated = relu(fc(x))
        return proj(activated * activated)
    }
}

// MARK: - Transformer Block

private final class NanoChatBlock: Module {
    let config: NanoChatConfiguration

    @ModuleInfo(key: "attn") var attention: NanoChatAttention
    @ModuleInfo(key: "mlp") var mlp: NanoChatMLP

    init(_ config: NanoChatConfiguration) {
        self.config = config
        _attention.wrappedValue = NanoChatAttention(config)
        _mlp.wrappedValue = NanoChatMLP(config)
    }

    func callAsFunction(
        _ x: MLXArray,
        mask: MLXFast.ScaledDotProductAttentionMaskMode,
        cache: KVCache?
    ) -> MLXArray {
        let attnOutput = attention(
            functionalRMSNorm(x, eps: config.rmsNormEps), mask: mask, cache: cache)
        let residual = x + attnOutput
        let mlpOutput = mlp(functionalRMSNorm(residual, eps: config.rmsNormEps))
        return residual + mlpOutput
    }
}

// MARK: - Model (inner)

private final class NanoChatModelInner: Module {
    let config: NanoChatConfiguration

    @ModuleInfo(key: "wte") var embedTokens: Embedding
    @ModuleInfo(key: "h") var layers: [NanoChatBlock]

    init(_ config: NanoChatConfiguration) {
        precondition(config.vocabularySize > 0)
        self.config = config

        _embedTokens.wrappedValue = Embedding(
            embeddingCount: config.vocabularySize,
            dimensions: config.hiddenSize
        )
        _layers.wrappedValue = (0 ..< config.hiddenLayers).map { _ in NanoChatBlock(config) }
    }

    func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var hidden = embedTokens(inputs)
        hidden = functionalRMSNorm(hidden, eps: config.rmsNormEps)

        let mask = createAttentionMask(h: hidden, cache: cache)

        for (index, layer) in layers.enumerated() {
            hidden = layer(hidden, mask: mask, cache: cache?[index])
        }

        return functionalRMSNorm(hidden, eps: config.rmsNormEps)
    }
}

// MARK: - Public Model

public final class NanoChatModel: Module, LLMModel, KVCacheDimensionProvider {
    public let vocabularySize: Int
    public let kvHeads: [Int]
    public let modelType: String

    let config: NanoChatConfiguration

    @ModuleInfo(key: "transformer") fileprivate var transformer: NanoChatModelInner
    @ModuleInfo(key: "lm_head") var lmHead: Linear

    public init(_ config: NanoChatConfiguration) {
        self.config = config
        self.modelType = config.modelType
        self.vocabularySize = config.vocabularySize
        self.kvHeads = Array(repeating: config.kvHeads, count: config.hiddenLayers)

        _transformer.wrappedValue = NanoChatModelInner(config)
        _lmHead.wrappedValue = Linear(config.hiddenSize, config.vocabularySize, bias: false)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        let hidden = transformer(inputs, cache: cache)
        let logits = lmHead(hidden)
        return applySoftcap(logits, cap: config.logitsSoftcap)
    }
}

// MARK: - Configuration

public struct NanoChatConfiguration: Codable, Sendable {
    public var modelType: String
    public var hiddenSize: Int
    public var hiddenLayers: Int
    public var attentionHeads: Int
    public var kvHeads: Int
    public var vocabularySize: Int
    public var maxPositionEmbeddings: Int
    public var intermediateSize: Int
    public var ropeTheta: Float
    public var rmsNormEps: Float
    public var logitsSoftcap: Float

    enum CodingKeys: String, CodingKey {
        case modelType = "model_type"
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case attentionHeads = "num_attention_heads"
        case kvHeads = "num_key_value_heads"
        case vocabularySize = "vocab_size"
        case maxPositionEmbeddings = "max_position_embeddings"
        case intermediateSize = "intermediate_size"
        case ropeTheta = "rope_theta"
        case rmsNormEps = "rms_norm_eps"
        case logitsSoftcap = "logits_softcap"
    }

    public init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)

        self.modelType =
            try container.decodeIfPresent(String.self, forKey: .modelType) ?? "nanochat"
        self.hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
        self.hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
        self.attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
        self.kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
        self.vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
        self.maxPositionEmbeddings = try container.decode(
            Int.self, forKey: .maxPositionEmbeddings)
        self.intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
        self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 10_000
        self.rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1e-5
        self.logitsSoftcap =
            try container.decodeIfPresent(Float.self, forKey: .logitsSoftcap) ?? 15.0
    }
}

// MARK: - LoRA

extension NanoChatModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        transformer.layers.map { ($0.attention, ["c_q", "c_v"]) }
    }
}
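For reference, a hedged smoke-test sketch (not part of the commit) that decodes a made-up minimal config into NanoChatConfiguration and runs a shape-only forward pass through NanoChatModel. The config values and the MLXLLM import are assumptions chosen for illustration; a real checkpoint ships its own config.json and weights.

import Foundation
import MLX
import MLXLLM

// Hypothetical shape check: all config values below are made up for illustration.
func nanoChatSmokeTest() throws {
    let json = """
    {
      "model_type": "nanochat",
      "hidden_size": 64,
      "num_hidden_layers": 2,
      "num_attention_heads": 4,
      "num_key_value_heads": 4,
      "vocab_size": 128,
      "max_position_embeddings": 256,
      "intermediate_size": 256
    }
    """.data(using: .utf8)!

    // rope_theta, rms_norm_eps, and logits_softcap fall back to their defaults
    // (10_000, 1e-5, 15.0) because the decoder uses decodeIfPresent for them.
    let config = try JSONDecoder().decode(NanoChatConfiguration.self, from: json)
    let model = NanoChatModel(config)

    // Weights are randomly initialized here, so only the logits shape is meaningful.
    let tokens = MLXArray(0 ..< 8).reshaped(1, 8)
    let logits = model(tokens, cache: nil)
    print(logits.shape)  // [1, 8, 128]
}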
