@@ -7,6 +7,86 @@ import MLXNN
 
 // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
 
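+// Implements the "llama3" rope_scaling rule: each base frequency is blended
+// between `freq` and `freq * factor` according to its wavelength relative to
+// the original context length, and the mean of the adjusted frequencies is
+// returned as the new base. For any other rope type the base is returned unchanged.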
+func computeBaseFrequency(
+    base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]?
+)
+    -> Float
+{
+    if ropeType != "llama3" {
+        return base
+    }
+
+    guard let ropeScaling = ropeScaling else {
+        return base
+    }
+
+    guard case .float(let factor) = ropeScaling["factor"],
+        case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
+        case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
+        case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
+            ?? .float(8192)
+    else {
+        return base
+    }
+
+    let lowFreqWavelen = oldContextLen / lowFreqFactor
+    let highFreqWavelen = oldContextLen / highFreqFactor
+
+    let freqs = (0 ..< dims).compactMap { index -> Float? in
+        if index % 2 == 0 {
+            return pow(base, Float(index) / Float(dims))
+        }
+        return nil
+    }
+
+    let newBaseFreqs = freqs.map { freq -> Float in
+        let wavelen = 2 * .pi / freq
+        let smooth = max(
+            0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
+        return freq * ((1 - smooth) * factor + smooth)
+    }
+
+    return newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count)
+}
+
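+// RoPE layer whose base is recomputed for "llama3" scaling at init and grown
+// dynamically (NTK-aware) when the sequence extends past `maxPositionEmbeddings`.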
+private class DynamicNTKScalingRoPE: Module {
+    let dims: Int
+    let maxPositionEmbeddings: Int?
+    let traditional: Bool
+    let base: Float
+    var scale: Float
+    let ropeType: String
+    let ropeScaling: [String: StringOrNumber]?
+
+    init(
+        dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false,
+        base: Float = 10000, scale: Float = 1.0, ropeType: String = "default",
+        ropeScaling: [String: StringOrNumber]? = nil
+    ) {
+        self.dims = dims
+        self.maxPositionEmbeddings = maxPositionEmbeddings
+        self.traditional = traditional
+        self.base = computeBaseFrequency(
+            base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling)
+        self.scale = scale
+        self.ropeType = ropeType
+        self.ropeScaling = ropeScaling
+    }
+
+    func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
+        let seqLen = x.dim(1) + offset
+        var base = self.base
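+        // Dynamic NTK scaling: when the sequence is longer than maxPositionEmbeddings,
+        // grow the base by scale * (1 + extra-length ratio) ^ (dims / (dims - 2)).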
+        if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings {
+            let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1
+            let dimensionRatio = Float(dims) / Float(Float(dims) - 2)
+            let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)
+            base *= adjustedScale
+        }
+        return MLXFast.RoPE(
+            x, dimensions: dims, traditional: traditional, base: base, scale: scale, offset: offset)
+    }
+}
+
 private class Attention: Module {
 
     let args: LlamaConfiguration
@@ -17,9 +97,9 @@ private class Attention: Module {
     @ModuleInfo(key: "v_proj") var wv: Linear
     @ModuleInfo(key: "o_proj") var wo: Linear
 
-    let rope: RoPE
+    let rope: DynamicNTKScalingRoPE
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         self.args = args
 
         let dim = args.hiddenSize
@@ -29,31 +109,28 @@ private class Attention: Module {
         let headDim = args.headDimensions ?? (args.hiddenSize / heads)
         self.scale = pow(Float(headDim), -0.5)
 
-        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
-        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
-        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
-
-        let ropeScale: Float
-        if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
-            let factor = ropeScaling["factor"]
-        {
-            switch factor {
-            case .string:
-                fatalError("ropeScaling.factor must be a float")
-            case .float(let v):
-                ropeScale = 1 / v
-            }
-        } else {
-            ropeScale = 1
-        }
-
-        self.rope = RoPE(
-            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
-            scale: ropeScale)
+        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias)
+        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
+        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
+        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias)
+
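+        // The rope type comes from rope_scaling["type"] when present, otherwise "default".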
+        self.rope = DynamicNTKScalingRoPE(
+            dims: headDim,
+            maxPositionEmbeddings: args.maxPositionEmbeddings,
+            traditional: args.ropeTraditional,
+            base: args.ropeTheta,
+            scale: 1.0,
+            ropeType: {
+                if case .string(let value) = args.ropeScaling?["type"] {
+                    return value
+                } else {
+                    return "default"
+                }
+            }(),
+            ropeScaling: args.ropeScaling)
     }
 
-    public func callAsFunction(
+    func callAsFunction(
         _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
     ) -> (MLXArray, (MLXArray, MLXArray)) {
         let (B, L) = (x.dim(0), x.dim(1))
@@ -62,7 +139,7 @@ private class Attention: Module {
         var keys = wk(x)
         var values = wv(x)
 
-        // prepare the queries, keys and values for the attention computation
+        // Prepare the queries, keys and values for the attention computation
         queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
         keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
         values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
@@ -93,35 +170,35 @@ private class MLP: Module, UnaryLayer {
     @ModuleInfo(key: "down_proj") var down: Linear
     @ModuleInfo(key: "up_proj") var up: Linear
 
-    public init(dimensions: Int, hiddenDimensions: Int) {
-        self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
-        self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
-        self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+    init(_ args: LlamaConfiguration) {
+        self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
+        self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias)
+        self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
     }
 
-    public func callAsFunction(_ x: MLXArray) -> MLXArray {
-        down(silu(gate(x)) * up(x))
+    func callAsFunction(_ x: MLXArray) -> MLXArray {
+        let activation = silu(gate(x))
+        return down(activation * up(x))
     }
 }
 
 private class TransformerBlock: Module {
-
     @ModuleInfo(key: "self_attn") var attention: Attention
-    let mlp: MLP
+    @ModuleInfo(key: "mlp") var mlp: MLP
 
     @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
     @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         self._attention.wrappedValue = Attention(args)
-        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        self._mlp.wrappedValue = MLP(args)
         self._inputLayerNorm.wrappedValue = RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
         self._postAttentionLayerNorm.wrappedValue = RMSNorm(
             dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
-    public func callAsFunction(
+    func callAsFunction(
         _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
     ) -> (MLXArray, (MLXArray, MLXArray)) {
         var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
@@ -132,27 +209,24 @@ private class TransformerBlock: Module {
     }
 }
 
-public class LlamaModelInner: Module {
+private class LlamaModelInner: Module {
 
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
 
-    fileprivate let layers: [TransformerBlock]
+    let layers: [TransformerBlock]
     let norm: RMSNorm
 
-    public init(_ args: LlamaConfiguration) {
+    init(_ args: LlamaConfiguration) {
         precondition(args.vocabularySize > 0)
 
         self._embedTokens.wrappedValue = Embedding(
             embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
 
-        self.layers = (0 ..< args.hiddenLayers)
-            .map { _ in
-                TransformerBlock(args)
-            }
+        self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
         self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
     }
 
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+    func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
         MLXArray, [(MLXArray, MLXArray)]
     ) {
         var h = embedTokens(inputs)
@@ -178,7 +252,7 @@ public class LlamaModelInner: Module {
 public class LlamaModel: Module, LLMModel {
 
     public let vocabularySize: Int
-    let model: LlamaModelInner
+    fileprivate let model: LlamaModelInner
 
     @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
@@ -202,7 +276,7 @@ public class LlamaModel: Module, LLMModel {
     }
 
     public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
-        // Remove unused precomputed rotary freqs
+        // Remove unused precomputed rotary frequencies
         weights.filter {
             !$0.key.contains("self_attn.rotary_emb.inv_freq")
         }
@@ -215,14 +289,17 @@ public struct LlamaConfiguration: Codable {
     var hiddenLayers: Int
     var intermediateSize: Int
     var attentionHeads: Int
-    var headDimensions: Int? = nil
+    var headDimensions: Int?
     var rmsNormEps: Float
     var vocabularySize: Int
     var kvHeads: Int
+    var maxPositionEmbeddings: Int?
     var ropeTheta: Float = 10_000
     var ropeTraditional: Bool = false
-    var ropeScaling: [String: StringOrNumber]? = nil
-    var tieWordEmbeddings: Bool = false
+    var ropeScaling: [String: StringOrNumber]?
+    var tieWordEmbeddings: Bool = true
+    var attentionBias: Bool = false
+    var mlpBias: Bool = false
 
     enum CodingKeys: String, CodingKey {
         case hiddenSize = "hidden_size"
@@ -233,45 +310,75 @@ public struct LlamaConfiguration: Codable {
         case rmsNormEps = "rms_norm_eps"
         case vocabularySize = "vocab_size"
         case kvHeads = "num_key_value_heads"
+        case maxPositionEmbeddings = "max_position_embeddings"
         case ropeTheta = "rope_theta"
         case ropeTraditional = "rope_traditional"
         case ropeScaling = "rope_scaling"
         case tieWordEmbeddings = "tie_word_embeddings"
+        case attentionBias = "attention_bias"
+        case mlpBias = "mlp_bias"
     }
 
     public init(from decoder: Decoder) throws {
-        // custom implementation to handle optional keys with required values
-        let container: KeyedDecodingContainer<LlamaConfiguration.CodingKeys> =
-            try decoder.container(
-                keyedBy: LlamaConfiguration.CodingKeys.self)
-
-        self.hiddenSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenSize)
-        self.hiddenLayers = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenLayers)
-        self.intermediateSize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
-        self.attentionHeads = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
-        self.headDimensions = try container.decodeIfPresent(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
-        self.rmsNormEps = try container.decode(
-            Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
-        self.vocabularySize = try container.decode(
-            Int.self, forKey: LlamaConfiguration.CodingKeys.vocabularySize)
-        self.kvHeads = try container.decode(Int.self, forKey: LlamaConfiguration.CodingKeys.kvHeads)
-        self.ropeTheta =
-            try container.decodeIfPresent(
-                Float.self, forKey: LlamaConfiguration.CodingKeys.ropeTheta)
-            ?? 10_000
-        self.ropeTraditional =
-            try container.decodeIfPresent(
-                Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
-        self.ropeScaling = try container.decodeIfPresent(
-            [String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
-        self.tieWordEmbeddings =
-            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
+        let container = try decoder.container(keyedBy: CodingKeys.self)
+
+        hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
+        hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
+        intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
+        attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
+        headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
+        rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
+        vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
+        kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
+        maxPositionEmbeddings = try container.decodeIfPresent(
+            Int.self, forKey: .maxPositionEmbeddings)
+        if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) {
+            self.ropeTheta = ropeTheta
+        }
+        if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional)
+        {
+            self.ropeTraditional = ropeTraditional
+        }
+        ropeScaling = try container.decodeIfPresent(
+            [String: StringOrNumber].self, forKey: .ropeScaling)
+        if let tieWordEmbeddings = try container.decodeIfPresent(
+            Bool.self, forKey: .tieWordEmbeddings)
+        {
+            self.tieWordEmbeddings = tieWordEmbeddings
+        }
+        if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) {
+            self.attentionBias = attentionBias
+        }
+        if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) {
+            self.mlpBias = mlpBias
+        }
 
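+        // Validate rope_scaling: it must contain "factor" and a supported
+        // "type" / "rope_type" ("linear", "dynamic", or "llama3").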
+        if let ropeScaling {
+            if ropeScaling["factor"] == nil {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain 'factor'")
+            }
+            if let ropeType = ropeScaling["type"] ?? ropeScaling["rope_type"] {
+                if case .string = ropeType {
+                    let options = [
+                        StringOrNumber.string("linear"), StringOrNumber.string("dynamic"),
+                        StringOrNumber.string("llama3"),
+                    ]
+                    if !options.contains(ropeType) {
+                        throw DecodingError.dataCorruptedError(
+                            forKey: .ropeScaling, in: container,
+                            debugDescription:
+                                "rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'"
+                        )
+                    }
+                }
+            } else {
+                throw DecodingError.dataCorruptedError(
+                    forKey: .ropeScaling, in: container,
+                    debugDescription: "rope_scaling must contain either 'type' or 'rope_type'")
+            }
+        }
     }
 }
 