Commit 885e520

Some fixes for gemma2 (#99)
* some fixes for gemma2
* format
* fixes
* format
1 parent ac6bdfc commit 885e520

File tree

5 files changed: +326 -110 lines changed

Libraries/LLM/Configuration.swift

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ public enum ModelType: String, Codable {
             return GemmaModel(configuration)
         case .gemma2:
             let configuration = try JSONDecoder().decode(
-                GemmaConfiguration.self, from: Data(contentsOf: configuration))
+                Gemma2Configuration.self, from: Data(contentsOf: configuration))
             return Gemma2Model(configuration)
         case .qwen2:
             let configuration = try JSONDecoder().decode(
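
The fix above decodes gemma2 checkpoints with a dedicated Gemma2Configuration instead of reusing GemmaConfiguration, presumably so that fields specific to Gemma 2 checkpoints can be decoded. As a minimal, self-contained sketch of that Codable-dispatch pattern (plain Foundation, no MLX; the Toy* types and their fields are illustrative assumptions, not the library's actual declarations):

import Foundation

struct ToyGemmaConfig: Codable {
    let hiddenSize: Int
}

struct ToyGemma2Config: Codable {
    let hiddenSize: Int
    let headDim: Int  // an extra field a Gemma-2-style config might carry
}

enum ToyModelType: String {
    case gemma, gemma2

    // Pick the Codable struct that matches the model type, decode the JSON,
    // and report what was produced.
    func describeConfiguration(from json: Data) throws -> String {
        switch self {
        case .gemma:
            let config = try JSONDecoder().decode(ToyGemmaConfig.self, from: json)
            return "gemma(hiddenSize: \(config.hiddenSize))"
        case .gemma2:
            let config = try JSONDecoder().decode(ToyGemma2Config.self, from: json)
            return "gemma2(hiddenSize: \(config.hiddenSize), headDim: \(config.headDim))"
        }
    }
}

do {
    let json = Data(#"{"hiddenSize": 2304, "headDim": 256}"#.utf8)
    print(try ToyModelType.gemma2.describeConfiguration(from: json))  // gemma2(hiddenSize: 2304, headDim: 256)
} catch {
    print("decode failed: \(error)")
}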

Libraries/LLM/Gemma.swift

Lines changed: 0 additions & 108 deletions
@@ -262,111 +262,3 @@ extension GemmaModel: LoRAModel {
         model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
     }
 }
-
-// Gemma 2
-
-// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
-
-// Minimal changes from Gemma TransformerBlock
-private class Gemma2TransformerBlock: Module {
-
-    @ModuleInfo(key: "self_attn") var attention: Attention
-    let mlp: MLP
-
-    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
-    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
-
-    public init(_ args: GemmaConfiguration) {
-        self._attention.wrappedValue = Attention(args)
-        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
-        self._inputLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-    }
-
-    public func callAsFunction(
-        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
-    ) -> (MLXArray, (MLXArray, MLXArray)) {
-        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
-        let h = x + postAttentionLayerNorm(r)
-        r = mlp(preFeedforwardLayerNorm(h))
-        let out = h + postFeedforwardLayerNorm(r)
-        return (out, cache)
-    }
-}
-
-// Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
-public class Gemma2ModelInner: Module {
-
-    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
-
-    fileprivate let layers: [Gemma2TransformerBlock]
-    fileprivate let norm: RMSNorm
-
-    let hiddenScale: Float
-
-    public init(_ args: GemmaConfiguration) {
-        precondition(args.vocabularySize > 0)
-
-        self._embedTokens.wrappedValue = Embedding(
-            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
-
-        self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
-
-        self.layers = (0 ..< args.hiddenLayers)
-            .map { _ in
-                Gemma2TransformerBlock(args)
-            }
-        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
-    }
-
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
-        MLXArray, [(MLXArray, MLXArray)]
-    ) {
-        var h = embedTokens(inputs)
-        h = h * hiddenScale
-
-        var mask: MLXArray? = nil
-        if h.dim(1) > 1 {
-            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
-            mask = mask?.asType(h.dtype)
-        }
-
-        var newCache = [(MLXArray, MLXArray)]()
-
-        for (i, layer) in layers.enumerated() {
-            var cacheUpdate: (MLXArray, MLXArray)
-            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
-            newCache.append(cacheUpdate)
-        }
-
-        return (norm(h), newCache)
-    }
-}
-
-// Uses Gemma2ModelInner, otherwise same as GemmaModel
-public class Gemma2Model: Module, LLMModel {
-
-    public let vocabularySize: Int
-    let model: Gemma2ModelInner
-
-    public init(_ args: GemmaConfiguration) {
-        self.vocabularySize = args.vocabularySize
-        self.model = Gemma2ModelInner(args)
-    }
-
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
-        MLXArray, [(MLXArray, MLXArray)]
-    ) {
-        var (out, cache) = model(inputs, cache: cache)
-        out = model.embedTokens.asLinear(out)
-        return (out, cache)
-    }
-}
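
This hunk removes the Gemma 2 classes from Gemma.swift; given the commit's +326 added lines across 5 files, they are presumably rebuilt elsewhere in this change with the fixes applied. The deleted block is still worth reading for Gemma 2's distinctive residual ordering: unlike the original Gemma block, which only normalizes the input of each sublayer, the post-attention and post-feedforward RMSNorms here are applied to the branch output before it is added back to the residual stream. A toy sketch of that ordering on [Float], with placeholder closures standing in for the real MLX attention and MLP modules (illustration only, not the library's code):

// Plain-Swift RMSNorm over a vector, no learned weight.
func rmsNorm(_ x: [Float], eps: Float = 1e-6) -> [Float] {
    let meanSquare = x.map { $0 * $0 }.reduce(0, +) / Float(x.count)
    let scale = 1 / (meanSquare + eps).squareRoot()
    return x.map { $0 * scale }
}

// Element-wise residual add.
func add(_ a: [Float], _ b: [Float]) -> [Float] {
    zip(a, b).map { $0 + $1 }
}

// Residual ordering of the (deleted) Gemma2TransformerBlock.callAsFunction.
func gemma2BlockForward(
    _ x: [Float],
    attention: ([Float]) -> [Float],
    mlp: ([Float]) -> [Float]
) -> [Float] {
    let r1 = attention(rmsNorm(x))   // input_layernorm -> attention
    let h = add(x, rmsNorm(r1))      // post_attention_layernorm applied to the branch
    let r2 = mlp(rmsNorm(h))         // pre_feedforward_layernorm -> MLP
    return add(h, rmsNorm(r2))       // post_feedforward_layernorm applied to the branch
}

let out = gemma2BlockForward(
    [0.5, -1.0, 2.0],
    attention: { $0.map { $0 * 0.9 } },  // placeholder "attention"
    mlp: { $0.map { $0 + 0.1 } }         // placeholder "MLP"
)
print(out)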
