Skip to content

Commit 5d89cc9

Browse files
authored
fix(gemma3n): support per-layer intermediate_size array (#46)
* fix(gemma3n): support per-layer intermediate_size array Gemma 3n models from HuggingFace specify intermediate_size as an array (one value per layer) rather than a single integer. This causes a decoding error when trying to load these models. This commit introduces an IntOrArray type that can decode either format, maintaining backwards compatibility with models that use a single value while adding support for the per-layer array format. Fixes loading of models like: - mlx-community/gemma-3n-E2B-it-4bit - mlx-community/gemma-3n-E4B-it-4bit Tested with swift build - compiles successfully. * fix(gemma3n): make query_pre_attn_scalar optional Some HuggingFace Gemma 3n configs don't include this field. * fix(gemma3n): preserve all weights in sanitize function The sanitize function was only keeping weights with 'model.language_model.' prefix and discarding all others. This caused missing weight errors when loading Gemma 3n models.
1 parent 9c20e79 commit 5d89cc9

File tree

1 file changed

+50
-5
lines changed

1 file changed

+50
-5
lines changed

Libraries/MLXLLM/Models/Gemma3nText.swift

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,59 @@ import MLXNN
1515

1616
// MARK: - Configuration
1717

18+
/// A type that can be decoded as either a single Int or an array of Ints.
19+
/// This is needed because some models (like Gemma 3n) specify intermediate_size
20+
/// as a per-layer array, while others use a single value.
21+
public struct IntOrArray: Codable {
22+
public let values: [Int]
23+
24+
public init(from decoder: Decoder) throws {
25+
let container = try decoder.singleValueContainer()
26+
if let array = try? container.decode([Int].self) {
27+
self.values = array
28+
} else if let single = try? container.decode(Int.self) {
29+
self.values = [single]
30+
} else {
31+
throw DecodingError.typeMismatch(
32+
IntOrArray.self,
33+
DecodingError.Context(
34+
codingPath: decoder.codingPath,
35+
debugDescription: "Expected Int or [Int]"
36+
)
37+
)
38+
}
39+
}
40+
41+
public func encode(to encoder: Encoder) throws {
42+
var container = encoder.singleValueContainer()
43+
if values.count == 1 {
44+
try container.encode(values[0])
45+
} else {
46+
try container.encode(values)
47+
}
48+
}
49+
50+
/// Get the intermediate size for a specific layer
51+
public subscript(layerIdx: Int) -> Int {
52+
if values.count == 1 {
53+
return values[0]
54+
}
55+
return values[layerIdx]
56+
}
57+
}
58+
1859
public struct Gemma3nTextConfiguration: Codable {
1960
let modelType: String
2061
let hiddenSize: Int
2162
let numHiddenLayers: Int
22-
let intermediateSize: Int
63+
let intermediateSize: IntOrArray
2364
let numAttentionHeads: Int
2465
let headDim: Int
2566
let rmsNormEps: Float
2667
let vocabSize: Int
2768
let numKeyValueHeads: Int
2869
let numKvSharedLayers: Int
29-
let queryPreAttnScalar: Float
70+
let queryPreAttnScalar: Float? // Optional - not present in all HF configs
3071
let vocabSizePerLayerInput: Int
3172
let slidingWindow: Int
3273
let maxPositionEmbeddings: Int
@@ -92,14 +133,14 @@ public struct Gemma3nTextConfiguration: Codable {
92133
modelType = try container.decode(String.self, forKey: .modelType)
93134
hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
94135
numHiddenLayers = try container.decode(Int.self, forKey: .numHiddenLayers)
95-
intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
136+
intermediateSize = try container.decode(IntOrArray.self, forKey: .intermediateSize)
96137
numAttentionHeads = try container.decode(Int.self, forKey: .numAttentionHeads)
97138
headDim = try container.decode(Int.self, forKey: .headDim)
98139
rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
99140
vocabSize = try container.decode(Int.self, forKey: .vocabSize)
100141
numKeyValueHeads = try container.decode(Int.self, forKey: .numKeyValueHeads)
101142
numKvSharedLayers = try container.decode(Int.self, forKey: .numKvSharedLayers)
102-
queryPreAttnScalar = try container.decode(Float.self, forKey: .queryPreAttnScalar)
143+
queryPreAttnScalar = try container.decodeIfPresent(Float.self, forKey: .queryPreAttnScalar)
103144
vocabSizePerLayerInput = try container.decode(Int.self, forKey: .vocabSizePerLayerInput)
104145
slidingWindow = try container.decode(Int.self, forKey: .slidingWindow)
105146
maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
@@ -309,7 +350,7 @@ class Gemma3nMLP: Module {
309350
init(_ config: Gemma3nTextConfiguration, layerIdx: Int) {
310351
self.config = config
311352
self.hiddenSize = config.hiddenSize
312-
self.intermediateSize = config.intermediateSize
353+
self.intermediateSize = config.intermediateSize[layerIdx]
313354

314355
if let activationSparsityPattern = config.activationSparsityPattern {
315356
self.activationSparsity = activationSparsityPattern[layerIdx]
@@ -963,9 +1004,13 @@ public class Gemma3nTextModel: Module, LLMModel {
9631004

9641005
for (key, value) in weights {
9651006
if key.hasPrefix("model.language_model.") {
1007+
// Remove "model." prefix for VLM-style weights
9661008
let newKey = key.replacingOccurrences(
9671009
of: "model.language_model.", with: "language_model.")
9681010
processedWeights[newKey] = value
1011+
} else {
1012+
// Keep other weights as-is
1013+
processedWeights[key] = value
9691014
}
9701015
}
9711016

0 commit comments

Comments
 (0)