@@ -262,111 +262,3 @@ extension GemmaModel: LoRAModel {
         model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
     }
 }
-
-// Gemma 2
-
-// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
-
-// Minimal changes from Gemma TransformerBlock
-private class Gemma2TransformerBlock: Module {
-
-    @ModuleInfo(key: "self_attn") var attention: Attention
-    let mlp: MLP
-
-    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
-    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
-
-    public init(_ args: GemmaConfiguration) {
-        self._attention.wrappedValue = Attention(args)
-        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
-        self._inputLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-    }
-
-    public func callAsFunction(
-        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
-    ) -> (MLXArray, (MLXArray, MLXArray)) {
-        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
-        let h = x + postAttentionLayerNorm(r)
-        r = mlp(preFeedforwardLayerNorm(h))
-        let out = h + postFeedforwardLayerNorm(r)
-        return (out, cache)
-    }
-}
-
-// Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
-public class Gemma2ModelInner: Module {
-
-    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
-
-    fileprivate let layers: [Gemma2TransformerBlock]
-    fileprivate let norm: RMSNorm
-
-    let hiddenScale: Float
-
-    public init(_ args: GemmaConfiguration) {
-        precondition(args.vocabularySize > 0)
-
-        self._embedTokens.wrappedValue = Embedding(
-            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
-
-        self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
-
-        self.layers = (0 ..< args.hiddenLayers)
-            .map { _ in
-                Gemma2TransformerBlock(args)
-            }
-        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
-    }
-
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
-        MLXArray, [(MLXArray, MLXArray)]
-    ) {
-        var h = embedTokens(inputs)
-        h = h * hiddenScale
-
-        var mask: MLXArray? = nil
-        if h.dim(1) > 1 {
-            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
-            mask = mask?.asType(h.dtype)
-        }
-
-        var newCache = [(MLXArray, MLXArray)]()
-
-        for (i, layer) in layers.enumerated() {
-            var cacheUpdate: (MLXArray, MLXArray)
-            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
-            newCache.append(cacheUpdate)
-        }
-
-        return (norm(h), newCache)
-    }
-}
-
-// Uses Gemma2ModelInner, otherwise same as GemmaModel
-public class Gemma2Model: Module, LLMModel {
-
-    public let vocabularySize: Int
-    let model: Gemma2ModelInner
-
-    public init(_ args: GemmaConfiguration) {
-        self.vocabularySize = args.vocabularySize
-        self.model = Gemma2ModelInner(args)
-    }
-
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
-        MLXArray, [(MLXArray, MLXArray)]
-    ) {
-        var (out, cache) = model(inputs, cache: cache)
-        out = model.embedTokens.asLinear(out)
-        return (out, cache)
-    }
-}
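For reference, the removed `Gemma2Model` follows the same `LLMModel` calling convention as the Gemma classes kept above: the caller passes the per-layer `(keys, values)` cache back in on every step. Below is a minimal usage sketch of that convention, not code from this repository; it assumes a `config: GemmaConfiguration` decoded from the checkpoint, weights already loaded into the module, and placeholder token ids.

    let model = Gemma2Model(config)

    // Prefill: run the whole prompt once; `logits` has shape [1, T, vocabularySize].
    let prompt = MLXArray([Int32(2), 651, 2872]).reshaped(1, 3)
    var (logits, cache) = model(prompt, cache: nil)

    // Decode: feed one token at a time, threading the cache back in so each
    // Gemma2TransformerBlock attends over the cached keys/values plus the new position.
    let next = MLXArray([Int32(42)]).reshaped(1, 1)
    (logits, cache) = model(next, cache: cache)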