
Commit ac6bdfc

DePasqualeOrg and awni authored
Add Llama 3.1 (#98)
* Update Mistral 7B config
* Add Mistral NeMo
* Update for Llama 3.1
* Align LlamaConfiguration with Python implementation
* Fix model configuration names
* Refine DynamicNTKScalingRoPE
* compute base only once

---------

Co-authored-by: Awni Hannun <[email protected]>
1 parent c4fda0e commit ac6bdfc

File tree

3 files changed: 202 additions & 86 deletions


Applications/LLMEval/ContentView.swift
Lines changed: 1 addition & 1 deletion

@@ -159,7 +159,7 @@ class LLMEvaluator {
 
     /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
     /// more devices
-    let modelConfiguration = ModelConfiguration.phi34bit
+    let modelConfiguration = ModelConfiguration.phi3_4bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)

Libraries/LLM/Llama.swift
Lines changed: 188 additions & 81 deletions

@@ -7,6 +7,86 @@ import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
 
+func computeBaseFrequency(
+    base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]?
+)
+    -> Float
+{
+    if ropeType != "llama3" {
+        return base
+    }
+
+    guard let ropeScaling = ropeScaling else {
+        return base
+    }
+
+    guard case .float(let factor) = ropeScaling["factor"],
+        case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
+        case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
+        case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
+            ?? .float(8192)
+    else {
+        return base
+    }
+
+    let lowFreqWavelen = oldContextLen / lowFreqFactor
+    let highFreqWavelen = oldContextLen / highFreqFactor
+
+    let freqs = (0 ..< dims).compactMap { index -> Float? in
+        if index % 2 == 0 {
+            return pow(base, Float(index) / Float(dims))
+        }
+        return nil
+    }
+
+    let newBaseFreqs = freqs.map { freq -> Float in
+        let wavelen = 2 * .pi / freq
+        let smooth = max(
+            0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
+        return freq * ((1 - smooth) * factor + smooth)
+    }
+
+    return newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count)
+}
+
+private class DynamicNTKScalingRoPE: Module {
+    let dims: Int
+    let maxPositionEmbeddings: Int?
+    let traditional: Bool
+    let base: Float
+    var scale: Float
+    let ropeType: String
+    let ropeScaling: [String: StringOrNumber]?
+
+    init(
+        dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false,
+        base: Float = 10000, scale: Float = 1.0, ropeType: String = "default",
+        ropeScaling: [String: StringOrNumber]? = nil
+    ) {
+        self.dims = dims
+        self.maxPositionEmbeddings = maxPositionEmbeddings
+        self.traditional = traditional
+        self.base = computeBaseFrequency(
+            base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling)
+        self.scale = scale
+        self.ropeType = ropeType
+        self.ropeScaling = ropeScaling
+    }
+
+    func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
+        let seqLen = x.dim(1) + offset
+        var base = self.base
+        if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings {
+            let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1
+            let dimensionRatio = Float(dims) / Float(Float(dims) - 2)
+            let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)
+            base *= adjustedScale
+        }
+        return MLXFast.RoPE(
+            x, dimensions: dims, traditional: traditional, base: base, scale: scale, offset: offset)
+    }
+}
+
 private class Attention: Module {
 
     let args: LlamaConfiguration
@@ -17,9 +97,9 @@ private class Attention: Module {
     @ModuleInfo(key: "v_proj") var wv: Linear
     @ModuleInfo(key: "o_proj") var wo: Linear
 
-    let rope: RoPE
+    let rope: DynamicNTKScalingRoPE
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         self.args = args
 
         let dim = args.hiddenSize
@@ -29,31 +109,28 @@ private class Attention: Module {
         let headDim = args.headDimensions ?? (args.hiddenSize / heads)
         self.scale = pow(Float(headDim), -0.5)
 
-        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
-        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
-
-        let ropeScale: Float
-        if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
-            let factor = ropeScaling["factor"]
-        {
-            switch factor {
-            case .string:
-                fatalError("ropeScaling.factor must be a float")
-            case .float(let v):
-                ropeScale = 1 / v
-            }
-        } else {
-            ropeScale = 1
-        }
-
-        self.rope = RoPE(
-            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
-            scale: ropeScale)
+        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias)
+        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
+        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
+        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias)
+
+        self.rope = DynamicNTKScalingRoPE(
+            dims: headDim,
+            maxPositionEmbeddings: args.maxPositionEmbeddings,
+            traditional: args.ropeTraditional,
+            base: args.ropeTheta,
+            scale: 1.0,
+            ropeType: {
+                if case .string(let value) = args.ropeScaling?["type"] {
+                    return value
+                } else {
+                    return "default"
+                }
+            }(),
+            ropeScaling: args.ropeScaling)
     }
 
-    public func callAsFunction(
+    func callAsFunction(
         _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
     ) -> (MLXArray, (MLXArray, MLXArray)) {
         let (B, L) = (x.dim(0), x.dim(1))
@@ -62,7 +139,7 @@ private class Attention: Module {
         var keys = wk(x)
         var values = wv(x)
 
-        // prepare the queries, keys and values for the attention computation
+        // Prepare the queries, keys and values for the attention computation
         queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
@@ -93,35 +170,35 @@ private class MLP: Module, UnaryLayer {
     @ModuleInfo(key: "down_proj") var down: Linear
     @ModuleInfo(key: "up_proj") var up: Linear
 
-    public init(dimensions: Int, hiddenDimensions: Int) {
-        self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
-        self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
-        self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+    init(_ args: LlamaConfiguration) {
+        self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
+        self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias)
+        self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
     }
 
-    public func callAsFunction(_ x: MLXArray) -> MLXArray {
-        down(silu(gate(x)) * up(x))
+    func callAsFunction(_ x: MLXArray) -> MLXArray {
+        let activation = silu(gate(x))
+        return down(activation * up(x))
     }
 }
 
 private class TransformerBlock: Module {
-
     @ModuleInfo(key: "self_attn") var attention: Attention
-    let mlp: MLP
+    @ModuleInfo(key: "mlp") var mlp: MLP
 
     @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
     @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         self._attention.wrappedValue = Attention(args)
-        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        self._mlp.wrappedValue = MLP(args)
         self._inputLayerNorm.wrappedValue = RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
         self._postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
-    public func callAsFunction(
+    func callAsFunction(
         _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
     ) -> (MLXArray, (MLXArray, MLXArray)) {
         var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
@@ -132,27 +209,24 @@ private class TransformerBlock: Module {
     }
 }
 
-public class LlamaModelInner: Module {
+private class LlamaModelInner: Module {
 
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
 
-    fileprivate let layers: [TransformerBlock]
+    let layers: [TransformerBlock]
     let norm: RMSNorm
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
        precondition(args.vocabularySize > 0)
 
         self._embedTokens.wrappedValue = Embedding(
             embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
 
-        self.layers = (0 ..< args.hiddenLayers)
-            .map { _ in
-                TransformerBlock(args)
-            }
+        self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
         self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+    func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
         var h = embedTokens(inputs)
@@ -178,7 +252,7 @@ public class LlamaModelInner: Module {
 public class LlamaModel: Module, LLMModel {
 
     public let vocabularySize: Int
-    let model: LlamaModelInner
+    fileprivate let model: LlamaModelInner
 
     @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
@@ -202,7 +276,7 @@ public class LlamaModel: Module, LLMModel {
     }
 
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
-        // Remove unused precomputed rotary freqs
+        // Remove unused precomputed rotary frequencies
        weights.filter {
             !$0.key.contains("self_attn.rotary_emb.inv_freq")
         }
@@ -215,14 +289,17 @@ public struct LlamaConfiguration: Codable {
     var hiddenLayers: Int
     var intermediateSize: Int
     var attentionHeads: Int
-    var headDimensions: Int? = nil
+    var headDimensions: Int?
     var rmsNormEps: Float
     var vocabularySize: Int
     var kvHeads: Int
+    var maxPositionEmbeddings: Int?
     var ropeTheta: Float = 10_000
     var ropeTraditional: Bool = false
-    var ropeScaling: [String: StringOrNumber]? = nil
-    var tieWordEmbeddings: Bool = false
+    var ropeScaling: [String: StringOrNumber]?
+    var tieWordEmbeddings: Bool = true
+    var attentionBias: Bool = false
+    var mlpBias: Bool = false
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -233,45 +310,75 @@ public struct LlamaConfiguration: Codable {
         case rmsNormEps = "rms_norm_eps"
         case vocabularySize = "vocab_size"
         case kvHeads = "num_key_value_heads"
+        case maxPositionEmbeddings = "max_position_embeddings"
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
         case tieWordEmbeddings = "tie_word_embeddings"
+        case attentionBias = "attention_bias"
+        case mlpBias = "mlp_bias"
     }
 
     public init(from decoder: Decoder) throws {
-        // custom implementation to handle optional keys with required values
-        let container: KeyedDecodingContainer<LlamaConfiguration.CodingKeys> =
-            try decoder.container(
-                keyedBy: LlamaConfiguration.CodingKeys.self)
-
-        self.hiddenSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenSize)
-        self.hiddenLayers = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenLayers)
-        self.intermediateSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
-        self.attentionHeads = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
-        self.headDimensions = try container.decodeIfPresent(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
-        self.rmsNormEps = try container.decode(
-            Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
-        self.vocabularySize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.vocabularySize)
-        self.kvHeads = try container.decode(Int.self, forKey: LlamaConfiguration.CodingKeys.kvHeads)
-        self.ropeTheta =
-            try container.decodeIfPresent(
-                Float.self, forKey: LlamaConfiguration.CodingKeys.ropeTheta)
-            ?? 10_000
-        self.ropeTraditional =
-            try container.decodeIfPresent(
-                Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
-        self.ropeScaling = try container.decodeIfPresent(
-            [String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
-        self.tieWordEmbeddings =
-            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
+        let container = try decoder.container(keyedBy: CodingKeys.self)
+
+        hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
+        hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
+        intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
+        attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
+        headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
+        rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
+        vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
+        kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
+        maxPositionEmbeddings = try container.decodeIfPresent(
+            Int.self, forKey: .maxPositionEmbeddings)
+        if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) {
+            self.ropeTheta = ropeTheta
+        }
+        if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional)
+        {
+            self.ropeTraditional = ropeTraditional
+        }
+        ropeScaling = try container.decodeIfPresent(
+            [String: StringOrNumber].self, forKey: .ropeScaling)
+        if let tieWordEmbeddings = try container.decodeIfPresent(
+            Bool.self, forKey: .tieWordEmbeddings)
+        {
+            self.tieWordEmbeddings = tieWordEmbeddings
+        }
+        if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) {
+            self.attentionBias = attentionBias
+        }
+        if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) {
+            self.mlpBias = mlpBias
+        }
 
+        if let ropeScaling {
+            if ropeScaling["factor"] == nil {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain 'factor'")
+            }
+            if let ropeType = ropeScaling["type"] ?? ropeScaling["rope_type"] {
+                if case .string = ropeType {
+                    let options = [
+                        StringOrNumber.string("linear"), StringOrNumber.string("dynamic"),
+                        StringOrNumber.string("llama3"),
+                    ]
+                    if !options.contains(ropeType) {
+                        throw DecodingError.dataCorruptedError(
+                            forKey: .ropeScaling, in: container,
+                            debugDescription:
+                                "rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'"
+                        )
+                    }
+                }
+            } else {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain either 'type' or 'rope_type'")
+            }
+        }
     }
 }
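A minimal usage sketch of the new computeBaseFrequency helper follows (editorial illustration, not part of this commit). It assumes the library's StringOrNumber enum with the .string and .float cases used above; the numeric values are illustrative Llama 3.1-style settings rather than values taken from this diff.

// Illustrative rope_scaling dictionary in the shape the helper expects.
let ropeScaling: [String: StringOrNumber] = [
    "rope_type": .string("llama3"),
    "factor": .float(8.0),
    "low_freq_factor": .float(1.0),
    "high_freq_factor": .float(4.0),
    "original_max_position_embeddings": .float(8192),
]

// For ropeType == "llama3", the per-dimension frequencies derived from the base
// are rescaled between the low/high wavelength thresholds and their average
// becomes the new base; any other ropeType returns the base unchanged.
let adjustedBase = computeBaseFrequency(
    base: 500_000, dims: 128, ropeType: "llama3", ropeScaling: ropeScaling)

// DynamicNTKScalingRoPE does this once in its initializer ("compute base only
// once" in the commit message), so the forward path only calls MLXFast.RoPE.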

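A second sketch, also illustrative rather than part of the commit, shows how the revised LlamaConfiguration decoder behaves on a config.json fragment. The key names not visible in this diff (num_hidden_layers, intermediate_size, num_attention_heads) are assumed to follow the standard Hugging Face Llama configuration, and StringOrNumber is assumed to decode the mixed string/number values inside rope_scaling.

import Foundation

let json = """
    {
        "hidden_size": 4096,
        "num_hidden_layers": 32,
        "intermediate_size": 14336,
        "num_attention_heads": 32,
        "rms_norm_eps": 1e-5,
        "vocab_size": 128256,
        "num_key_value_heads": 8,
        "max_position_embeddings": 131072,
        "rope_theta": 500000.0,
        "rope_scaling": {
            "rope_type": "llama3",
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192
        }
    }
    """.data(using: .utf8)!

do {
    // Missing optional keys fall back to the defaults declared on the struct.
    let config = try JSONDecoder().decode(LlamaConfiguration.self, from: json)
    print(config)
} catch {
    // A rope_scaling block without "factor", or with an unsupported
    // "type"/"rope_type", is rejected as DecodingError.dataCorrupted.
    print("invalid configuration:", error)
}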