Commit 4d20785

add support for OpenELM (#63)

* add support for OpenELM
* register model configuration for bootstrap

1 parent dfd79d0 · commit 4d20785

File tree (4 files changed: +332 −0 lines)

  Libraries/LLM/Configuration.swift
  Libraries/LLM/Models.swift
  Libraries/LLM/OpenELM.swift
  mlx-swift-examples.xcodeproj/project.pbxproj

Libraries/LLM/Configuration.swift
Lines changed: 5 additions & 0 deletions

@@ -35,6 +35,7 @@ public enum ModelType: String, Codable {
     case qwen2
     case starcoder2
     case cohere
+    case openelm

     public func createModel(configuration: URL) throws -> LLMModel {
         switch self {
@@ -66,6 +67,10 @@ public enum ModelType: String, Codable {
             let configuration = try JSONDecoder().decode(
                 CohereConfiguration.self, from: Data(contentsOf: configuration))
             return CohereModel(configuration)
+        case .openelm:
+            let configuration = try JSONDecoder().decode(
+                OpenElmConfiguration.self, from: Data(contentsOf: configuration))
+            return OpenELMModel(configuration)
         }
     }
 }
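
With the new case in place, an OpenELM checkpoint's config.json is decoded and turned into a model through the same path as the other architectures. A minimal usage sketch (the local path is hypothetical):

// Hypothetical local path; any config.json in OpenELM's format works.
let configURL = URL(fileURLWithPath: "/path/to/OpenELM-270M-Instruct/config.json")
let model = try ModelType.openelm.createModel(configuration: configURL)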

Libraries/LLM/Models.swift
Lines changed: 7 additions & 0 deletions

@@ -137,6 +137,12 @@ extension ModelConfiguration {
         "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
     }

+    public static let openelm270m4bit = ModelConfiguration(
+        id: "mlx-community/OpenELM-270M-Instruct"
+    ) { prompt in
+        "\(prompt)"
+    }
+
     private enum BootstrapState {
         case idle
         case bootstrapping
@@ -156,6 +162,7 @@ extension ModelConfiguration {
             phi34bit,
             gemma2bQuantized,
             qwen205b4bit,
+            openelm270m4bit,
         ])
         bootstrapState = .bootstrapped
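
Note that the registered preset uses a pass-through prompt template: the closure returns the raw prompt with no chat markers. Once the bootstrap list has been registered, the preset is addressable directly; a small sketch:

let config = ModelConfiguration.openelm270m4bit
print(config.id)  // mlx-community/OpenELM-270M-Instruct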

Libraries/LLM/OpenELM.swift
Lines changed: 316 additions & 0 deletions (new file)

@@ -0,0 +1,316 @@
//
//  OpenELM.swift
//  LLM
//
//  Created by Sachin Desai on 2024/4/27.
//

import Foundation
import MLX
import MLXFast
import MLXNN

func computeHeads(modelDim: Int, headDim: Int) -> Int {
    assert(modelDim % headDim == 0, "modelDim must be divisible by headDim")
    return modelDim / headDim
}

func makeDivisible(_ v: Float, divisor: Int = 8, minValue: Float? = nil) -> Int {
    let minVal = minValue ?? Float(divisor)
    var roundDown = max(minVal, Float(Int((v + Float(divisor) / 2) / Float(divisor)) * divisor))

    if roundDown < 0.9 * v {
        roundDown += Float(divisor)
    }
    return Int(roundDown)
}
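
A few worked values for makeDivisible, which rounds a requested width to a multiple of `divisor` while never landing below 90% of the request (pure Swift, illustrative inputs):

makeDivisible(100, divisor: 8)   // 104: (100 + 4) / 8 truncates to 13, times 8
makeDivisible(100, divisor: 16)  // 96: round-down stays within 90% of 100, so it stands
makeDivisible(10, divisor: 8)    // 16: 8 < 0.9 * 10, so one extra divisor is added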
private class MultiHeadCausalAttention: Module {
    var args: OpenElmConfiguration
    let scale: Float
    let heads: Int
    let headDim: Int
    let kvHeads: Int

    @ModuleInfo(key: "qkv_proj") var qkvProj: Linear
    @ModuleInfo(key: "out_proj") var outProj: Linear

    @ModuleInfo(key: "q_norm") var qNorm: RMSNorm
    @ModuleInfo(key: "k_norm") var kNorm: RMSNorm

    let rope: RoPE

    public init(_ args: OpenElmConfiguration, layerId: Int) {
        self.args = args
        self.headDim = args.headDimensions
        let modelDim = args.modelDim

        self.heads = self.args.numQueryHeads[layerId]
        self.kvHeads = self.args.kvHeads[layerId]
        self.scale = pow(Float(headDim), -0.5)

        let opSize = (heads + (kvHeads * 2)) * headDim
        self._qkvProj.wrappedValue = Linear(modelDim, opSize, bias: false)
        self._outProj.wrappedValue = Linear(heads * headDim, modelDim, bias: false)

        if args.normalizeQkProjections {
            self._qNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
            self._kNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
        }

        self.rope = RoPE(
            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
    ) -> (MLXArray, (MLXArray, MLXArray)) {
        let (B, L) = (x.dim(0), x.dim(1))
        let qkv = qkvProj(x).reshaped(B, L, heads + (kvHeads * 2), headDim).transposed(0, 2, 1, 3)

        let qkvSplit = split(qkv, indices: [heads, heads + kvHeads], axis: 1)
        var queries = qkvSplit[0]
        var keys = qkvSplit[1]
        var values = qkvSplit[2]

        if args.normalizeQkProjections {
            queries = qNorm(queries)
            keys = kNorm(keys)
        }

        if let (keyCache, valueCache) = cache {
            queries = rope(queries, offset: keyCache.dim(2))
            keys = rope(keys, offset: keyCache.dim(2))
            keys = concatenated([keyCache, keys], axis: 2)
            values = concatenated([valueCache, values], axis: 2)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask
        ).transposed(0, 2, 1, 3).reshaped(B, L, heads * headDim)

        return (outProj(output), (keys, values))
    }
}
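
The attention block fuses Q, K, and V into one projection and splits on the head axis, which is what lets each layer carry its own mix of query and key/value head counts (grouped-query attention). A shape sketch with illustrative sizes (not taken from a real config):

let (heads, kvHeads, headDim) = (12, 3, 64)
let opSize = (heads + 2 * kvHeads) * headDim  // 1152 output features from qkvProj
let splitPoints = [heads, heads + kvHeads]    // [12, 15]: 12 Q heads, 3 K, 3 V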
private class FeedForwardNetwork: Module, UnaryLayer {
    @ModuleInfo var proj_1: Linear
    @ModuleInfo var proj_2: Linear

    public init(_ args: OpenElmConfiguration, layedId: Int) {
        let dim = args.modelDim
        let ffnMultiplier = args.ffnMultipliers[layedId]
        let intermediateDim = Int(
            makeDivisible(Float(ffnMultiplier) * Float(dim), divisor: args.ffnDimDivisor))

        self.proj_1 = Linear(dim, 2 * intermediateDim)
        self.proj_2 = Linear(intermediateDim, dim)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        let a = proj_1(x)
        let b = split(a, parts: 2, axis: -1)
        let gate = b[0]
        let x = b[1]
        return proj_2(silu(gate) * x)
    }
}
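
`proj_1` produces twice the intermediate width so its output can be split into a gate and a value, giving a SiLU-gated (GLU-style) feed-forward. The intermediate width itself comes from `makeDivisible`; a worked example with assumed sizes (modelDim = 1280, multiplier 0.5, divisor 256 are illustrative):

let modelDim = 1280
let intermediateDim = makeDivisible(0.5 * Float(modelDim), divisor: 256)
// (640 + 128) / 256 truncates to 3, so intermediateDim = 768:
// proj_1 is Linear(1280, 1536), proj_2 is Linear(768, 1280)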
private class TransformerDecoderLayer: Module {
    @ModuleInfo(key: "attn") var attn: MultiHeadCausalAttention
    let ffn: FeedForwardNetwork

    @ModuleInfo(key: "ffn_norm") var ffnNorm: RMSNorm
    @ModuleInfo(key: "attn_norm") var attnNorm: RMSNorm

    public init(_ args: OpenElmConfiguration, layerId: Int) {
        let dim = args.modelDim
        self._attn.wrappedValue = MultiHeadCausalAttention(args, layerId: layerId)
        self.ffn = FeedForwardNetwork(args, layedId: layerId)
        self._ffnNorm.wrappedValue = RMSNorm(dimensions: dim, eps: args.rmsNormEps)
        self._attnNorm.wrappedValue = RMSNorm(dimensions: dim, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
    ) -> (MLXArray, (MLXArray, MLXArray)) {
        var (r, cache) = attn(attnNorm(x), mask: mask, cache: cache)
        let h = x + r
        r = ffn(ffnNorm(h))
        let out = h + r
        return (out, cache)
    }
}
class OpenELMModelInner: Module, LLMModel {
149+
var vocabularySize: Int
150+
151+
@ModuleInfo(key: "token_embeddings") var embedTokens: Embedding
152+
153+
fileprivate let layers: [TransformerDecoderLayer]
154+
fileprivate let norm: RMSNorm
155+
156+
public init(_ args: OpenElmConfiguration) {
157+
precondition(args.vocabularySize > 0)
158+
159+
self.vocabularySize = args.vocabularySize
160+
self._embedTokens.wrappedValue = Embedding(
161+
embeddingCount: self.vocabularySize, dimensions: args.modelDim)
162+
163+
self.layers = (0 ..< args.numTransformerLayers)
164+
.map { layerId in
165+
TransformerDecoderLayer(args, layerId: layerId)
166+
}
167+
168+
self.norm = RMSNorm(dimensions: args.modelDim, eps: args.rmsNormEps)
169+
}
170+
171+
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
172+
MLXArray, [(MLXArray, MLXArray)]
173+
) {
174+
var h = embedTokens(inputs)
175+
var mask: MLXArray? = nil
176+
if h.dim(1) > 1 {
177+
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
178+
mask = mask?.asType(h.dtype)
179+
}
180+
181+
var newCache = [(MLXArray, MLXArray)]()
182+
for (i, layer) in layers.enumerated() {
183+
var cacheUpdate: (MLXArray, MLXArray)
184+
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
185+
newCache.append(cacheUpdate)
186+
}
187+
188+
return (norm(h), newCache)
189+
}
190+
}
191+
192+
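
The causal mask is only built during prefill (more than one token in flight); single-token decode steps lean on the key/value cache and need no mask. A minimal sketch of the two call patterns, where `promptTokens` and `nextToken` are assumed MLXArrays of token ids:

// Prefill: multi-token prompt, mask created internally.
let (logits, cache) = model(promptTokens, cache: nil)
// Decode: one token at a time, cache supplies the past keys/values.
let (nextLogits, newCache) = model(nextToken, cache: cache)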
public class OpenELMModel: Module, LLMModel {
    public let vocabularySize: Int
    let shareInputOutputLayers: Bool
    let transformer: OpenELMModelInner

    @ModuleInfo(key: "lm_head") var lmHead: Linear

    public init(_ args: OpenElmConfiguration) {
        self.vocabularySize = args.vocabularySize
        self.transformer = OpenELMModelInner(args)
        self.shareInputOutputLayers = args.shareInputOutputLayers
        self._lmHead.wrappedValue = Linear(
            args.numTransformerLayers, args.vocabularySize, bias: false)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
        MLXArray, [(MLXArray, MLXArray)]
    ) {
        var (out, cache) = transformer(inputs, cache: cache)
        if shareInputOutputLayers {
            out = matmul(out, transformer.embedTokens.weight.T)
        } else {
            out = lmHead(out)
        }

        return (out, cache)
    }
}
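
When `share_input_output_layers` is set (the released OpenELM configurations appear to enable it, so `lmHead` goes unused), logits come from multiplying hidden states against the transposed token embedding: the embedding matrix serves as both input and output projection. A minimal shape sketch, assuming MLX's `zeros` free function:

import MLX

let hidden = zeros([1, 4, 8])             // (batch, tokens, modelDim)
let embedding = zeros([100, 8])           // (vocabularySize, modelDim)
let logits = matmul(hidden, embedding.T)  // (1, 4, 100): one score per vocab entry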
public struct OpenElmConfiguration: Codable {
    var modelType: String
    var headDimensions: Int
    var numTransformerLayers: Int
    var modelDim: Int
    var vocabularySize: Int
    var ffnDimDivisor: Int
    var numQueryHeads: [Int] = []
    var kvHeads: [Int] = []
    var ffnWithGlu: Bool = true
    var normalizeQkProjections: Bool = true
    var shareInputOutputLayers: Bool = true
    var rmsNormEps: Float = 1e-6
    var ropeTheta: Float = 10_000
    var ropeTraditional: Bool = false
    var numGqaGroups: Int = 4
    var ffnMultipliers: [Float] = [0.5, 4.0]
    var qkvMultiplier: [Float] = [0.5, 1.0]

    enum CodingKeys: String, CodingKey {
        case modelType = "model_type"
        case headDimensions = "head_dim"
        case numTransformerLayers = "num_transformer_layers"
        case modelDim = "model_dim"
        case vocabularySize = "vocab_size"
        case ffnDimDivisor = "ffn_dim_divisor"
        case ffnMultipliers = "ffn_multipliers"
        case ffnWithGlu = "ffn_with_glu"
        case normalizeQkProjections = "normalize_qk_projections"
        case shareInputOutputLayers = "share_input_output_layers"
    }

    public init(from decoder: Decoder) throws {
        // custom implementation to handle optional keys with required values
        let container: KeyedDecodingContainer<OpenElmConfiguration.CodingKeys> =
            try decoder.container(
                keyedBy: OpenElmConfiguration.CodingKeys.self)

        self.modelType = try container.decode(
            String.self, forKey: OpenElmConfiguration.CodingKeys.modelType)
        self.headDimensions = try container.decode(
            Int.self, forKey: OpenElmConfiguration.CodingKeys.headDimensions)
        self.numTransformerLayers = try container.decode(
            Int.self, forKey: OpenElmConfiguration.CodingKeys.numTransformerLayers)

        self.modelDim = try container.decode(
            Int.self, forKey: OpenElmConfiguration.CodingKeys.modelDim)
        self.vocabularySize = try container.decode(
            Int.self, forKey: OpenElmConfiguration.CodingKeys.vocabularySize)
        self.ffnDimDivisor = try container.decode(
            Int.self, forKey: OpenElmConfiguration.CodingKeys.ffnDimDivisor)

        let qkvMultipliers = stride(
            from: qkvMultiplier[0], through: qkvMultiplier[1],
            by: (qkvMultiplier[1] - qkvMultiplier[0]) / Float(numTransformerLayers - 1)
        )
        .map { round($0 * 100) / 100 }

        let headMultipleOf = numGqaGroups
        let queryDims = qkvMultipliers.map { a in
            makeDivisible(Float(self.modelDim) * a, divisor: self.headDimensions * headMultipleOf)
        }

        self.numQueryHeads = queryDims.map { qDim in
            Int(computeHeads(modelDim: qDim, headDim: self.headDimensions))
        }

        self.kvHeads = self.numQueryHeads.map { qHeads in
            qHeads / numGqaGroups
        }

        self.ffnMultipliers = stride(
            from: ffnMultipliers[0], through: ffnMultipliers[1],
            by: (ffnMultipliers[1] - ffnMultipliers[0]) / Float(numTransformerLayers - 1)
        )
        .map { round($0 * 100) / 100 }

        self.ffnWithGlu =
            try container.decodeIfPresent(
                Bool.self, forKey: OpenElmConfiguration.CodingKeys.ffnWithGlu) ?? true
        self.normalizeQkProjections =
            try container.decodeIfPresent(
                Bool.self, forKey: OpenElmConfiguration.CodingKeys.normalizeQkProjections) ?? true
        self.shareInputOutputLayers =
            try container.decodeIfPresent(
                Bool.self, forKey: OpenElmConfiguration.CodingKeys.shareInputOutputLayers) ?? true
    }
}
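
The decoder derives the per-layer head counts rather than reading them from JSON: qkv multipliers are interpolated from 0.5 to 1.0 across the depth, query widths are rounded with `makeDivisible`, and kv heads are query heads divided by `numGqaGroups`. A worked sketch using the helpers above (modelDim = 1280 and headDim = 64 are illustrative values):

let (modelDim, headDim, numGqaGroups) = (1280, 64, 4)
for multiplier: Float in [0.5, 1.0] {  // first and last layer
    let qDim = makeDivisible(multiplier * Float(modelDim), divisor: headDim * numGqaGroups)
    let qHeads = computeHeads(modelDim: qDim, headDim: headDim)
    print(qHeads, qHeads / numGqaGroups)  // 12 3 for the first layer, 20 5 for the last
}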
// MARK: - LoRA

extension OpenELMModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        transformer.layers.map { ($0.attn, ["qkv_proj"]) }
    }
}
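
For LoRA fine-tuning, only the fused `qkv_proj` in each attention block is exposed for adaptation, matching the convention the library's other model implementations follow.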

mlx-swift-examples.xcodeproj/project.pbxproj
Lines changed: 4 additions & 0 deletions

@@ -11,6 +11,7 @@
 	1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1CD79C6F2BD80DE100B6C06F /* Phi3.swift */; };
 	525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
 	52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
+	7BBD0D6E2BE044A10019C5D7 /* OpenELM.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */; };
 	81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
 	819BEFF82BAF8B4E0002CCEE /* DeviceStat.swift in Sources */ = {isa = PBXBuildFile; fileRef = 819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */; };
 	C3056BAE2BCD97B700A31D04 /* LoRATrainingExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */; };
@@ -220,6 +221,7 @@
 	1CD79C6F2BD80DE100B6C06F /* Phi3.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Phi3.swift; sourceTree = "<group>"; };
 	525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
 	52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
+	7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpenELM.swift; sourceTree = "<group>"; };
 	819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DeviceStat.swift; sourceTree = "<group>"; };
 	C3056BA12BCD973400A31D04 /* test.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = test.jsonl; sourceTree = "<group>"; };
 	C3056BA22BCD973400A31D04 /* train.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = train.jsonl; sourceTree = "<group>"; };
@@ -470,6 +472,7 @@
 	C38935C62B869C7A0037B833 /* LLM */ = {
 		isa = PBXGroup;
 		children = (
+			7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */,
 			C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */,
 			C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */,
 			525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */,
@@ -1006,6 +1009,7 @@
 			525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
 			C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */,
 			C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */,
+			7BBD0D6E2BE044A10019C5D7 /* OpenELM.swift in Sources */,
 			C38935DF2B869DD00037B833 /* Phi.swift in Sources */,
 			C38935CE2B869C870037B833 /* Load.swift in Sources */,
 			C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,

0 commit comments