Commit 6ff3c1f

feat: add Qwen3 embedding (#402)
* feat: add Qwen3 embedding model implementation
1 parent f913e4c commit 6ff3c1f

6 files changed: +300 -20 lines

Libraries/Embedders/Configuration.swift

Lines changed: 7 additions & 0 deletions

@@ -69,6 +69,13 @@ private class ModelTypeRegistry: @unchecked Sendable {
             let model = NomicBertModel(configuration)
             return model
         },
+        "qwen3": {
+            url in
+            let configuration = try JSONDecoder().decode(
+                Qwen3Configuration.self, from: Data(contentsOf: url))
+            let model = Qwen3Model(configuration)
+            return model
+        },
     ]

     public func registerModelType(
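
The factory body is what the embedder loader runs, presumably once it resolves a checkpoint's `config.json`. A minimal standalone sketch of the same steps, with a hypothetical local path standing in for the `url` the registry supplies:

import Foundation

// Hypothetical path to a downloaded checkpoint's config.json;
// the registry normally provides this URL itself.
let url = URL(fileURLWithPath: "Qwen3-Embedding-0.6B/config.json")
let configuration = try JSONDecoder().decode(
    Qwen3Configuration.self, from: Data(contentsOf: url))
let model = Qwen3Model(configuration)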

Libraries/Embedders/Models.swift

Lines changed: 3 additions & 0 deletions

@@ -108,6 +108,8 @@ extension ModelConfiguration {
     public static let bge_m3 = ModelConfiguration(id: "BAAI/bge-m3")
     public static let mixedbread_large = ModelConfiguration(
         id: "mixedbread-ai/mxbai-embed-large-v1")
+    public static let qwen3_embedding = ModelConfiguration(
+        id: "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ")

     private enum BootstrapState: Sendable {
         case idle
@@ -138,6 +140,7 @@ extension ModelConfiguration {
             snowflake_lg,
             bge_m3,
             mixedbread_large,
+            qwen3_embedding,
         ])
         bootstrapState = .bootstrapped
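
Call sites can then refer to the new preset like any other entry in the list above (a trivial sketch):

let configuration = ModelConfiguration.qwen3_embedding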

Libraries/Embedders/NomicBert.swift

Lines changed: 1 addition & 1 deletion

@@ -312,7 +312,7 @@ private class Encoder: Module {

     func callAsFunction(_ inputs: MLXArray, attentionMask: MLXArray? = nil) -> MLXArray {
         var outputs = inputs
-        for (index, layer) in layers.enumerated() {
+        for layer in layers {
             outputs = layer(outputs, mask: attentionMask)
         }
         return outputs
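
The dropped `index` binding was never used, so iterating with plain `for layer in layers` silences the unused-variable warning without changing behavior.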

Libraries/Embedders/Qwen3.swift

Lines changed: 278 additions & 0 deletions

@@ -0,0 +1,278 @@
// Copyright © 2024 Apple Inc.

import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN

private class Attention: Module {
    let args: Qwen3Configuration
    let scale: Float

    @ModuleInfo(key: "q_proj") var wq: Linear
    @ModuleInfo(key: "k_proj") var wk: Linear
    @ModuleInfo(key: "v_proj") var wv: Linear
    @ModuleInfo(key: "o_proj") var wo: Linear

    @ModuleInfo(key: "q_norm") var qNorm: RMSNorm
    @ModuleInfo(key: "k_norm") var kNorm: RMSNorm

    let rope: RoPE

    public init(_ args: Qwen3Configuration) {
        self.args = args

        let dim = args.hiddenSize
        let heads = args.attentionHeads
        let kvHeads = args.kvHeads

        let headDim = args.headDim
        self.scale = Float(pow(Double(headDim), -0.5))

        _wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
        _wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
        _wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
        _wo.wrappedValue = Linear(heads * headDim, dim, bias: false)

        _qNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)
        _kNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps)

        var ropeScale: Float = 1
        if let ropeScaling = args.ropeScaling,
            let typeValue = ropeScaling["type"],
            case .string(let type) = typeValue, type == "linear",
            let factorValue = ropeScaling["factor"]
        {
            switch factorValue {
            case .float(let v):
                ropeScale = 1 / v
            case .string(let s) where Float(s) != nil:
                ropeScale = 1 / Float(s)!
            default:
                break
            }
        }

        self.rope = RoPE(
            dimensions: headDim, traditional: false, base: args.ropeTheta,
            scale: ropeScale)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
    ) -> MLXArray {
        let (B, L) = (x.dim(0), x.dim(1))

        var queries = wq(x)
        var keys = wk(x)
        var values = wv(x)

        // prepare the queries, keys and values for the attention computation
        queries = qNorm(queries.reshaped(B, L, args.attentionHeads, -1)).transposed(0, 2, 1, 3)
        keys = kNorm(keys.reshaped(B, L, args.kvHeads, -1)).transposed(0, 2, 1, 3)
        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

        if let cache {
            queries = rope(queries, offset: cache.offset)
            keys = rope(keys, offset: cache.offset)
            (keys, values) = cache.update(keys: keys, values: values)
        } else {
            queries = rope(queries)
            keys = rope(keys)
        }

        let output = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale, mask: mask
        )
        .transposed(0, 2, 1, 3)
        .reshaped(B, L, -1)

        return wo(output)
    }
}

private class MLP: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gate: Linear
    @ModuleInfo(key: "down_proj") var down: Linear
    @ModuleInfo(key: "up_proj") var up: Linear

    public init(dimensions: Int, hiddenDimensions: Int) {
        _gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
        _down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
        _up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
    }

    public func callAsFunction(_ x: MLXArray) -> MLXArray {
        down(silu(gate(x)) * up(x))
    }
}

private class TransformerBlock: Module {
    @ModuleInfo(key: "self_attn") var attention: Attention
    let mlp: MLP

    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm

    public init(_ args: Qwen3Configuration) {
        _attention.wrappedValue = Attention(args)
        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
        _inputLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
        _postAttentionLayerNorm.wrappedValue = RMSNorm(
            dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(
        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
    ) -> MLXArray {
        var r = attention(inputLayerNorm(x), mask: mask, cache: cache)
        let h = x + r
        r = mlp(postAttentionLayerNorm(h))
        let out = h + r
        return out
    }
}

private class Qwen3ModelInner: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    fileprivate let layers: [TransformerBlock]
    let norm: RMSNorm

    public init(_ args: Qwen3Configuration) {
        precondition(args.vocabularySize > 0)

        _embedTokens.wrappedValue = Embedding(
            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)

        self.layers = (0 ..< args.hiddenLayers)
            .map { _ in
                TransformerBlock(args)
            }
        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
        var h = embedTokens(inputs)

        let mask: MLXArray? = createAttentionMask(h: h, cache: cache)

        for (i, layer) in layers.enumerated() {
            h = layer(h, mask: mask, cache: cache?[i])
        }

        return norm(h)
    }
}

public class Qwen3Model: Module, EmbeddingModel {
    public let vocabularySize: Int
    public let kvHeads: [Int]

    @ModuleInfo(key: "model") private var model: Qwen3ModelInner
    let configuration: Qwen3Configuration

    public init(_ args: Qwen3Configuration) {
        self.configuration = args
        self.vocabularySize = args.vocabularySize
        self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
        self._model.wrappedValue = Qwen3ModelInner(args)
    }

    public func callAsFunction(
        _ inputIds: MLXArray, positionIds: MLXArray? = nil, tokenTypeIds: MLXArray? = nil,
        attentionMask: MLXArray? = nil
    )
        -> EmbeddingModelOutput
    {
        let out = model(inputIds, cache: nil)
        return EmbeddingModelOutput(
            hiddenStates: out,
            pooledOutput: nil)
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        var sanitizedWeights = [String: MLXArray]()

        for (key, value) in weights {
            // Skip unused keys
            if key.contains("self_attn.rotary_emb.inv_freq") || key.contains("lm_head") {
                continue
            }

            var newKey = key
            if !newKey.hasPrefix("model.") {
                newKey = "model." + newKey
            }

            sanitizedWeights[newKey] = value
        }

        return sanitizedWeights
    }
}

public struct Qwen3Configuration: Codable, Sendable {
    var hiddenSize: Int
    var hiddenLayers: Int
    var intermediateSize: Int
    var attentionHeads: Int
    var rmsNormEps: Float
    var vocabularySize: Int
    var kvHeads: Int
    var ropeTheta: Float = 1_000_000
    var headDim: Int
    var ropeScaling: [String: StringOrNumber]? = nil
    var tieWordEmbeddings = false
    var maxPositionEmbeddings: Int = 32768

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case attentionHeads = "num_attention_heads"
        case rmsNormEps = "rms_norm_eps"
        case vocabularySize = "vocab_size"
        case kvHeads = "num_key_value_heads"
        case ropeTheta = "rope_theta"
        case headDim = "head_dim"
        case ropeScaling = "rope_scaling"
        case tieWordEmbeddings = "tie_word_embeddings"
        case maxPositionEmbeddings = "max_position_embeddings"
    }

    public init(from decoder: Decoder) throws {
        // custom implementation to handle optional keys with required values
        let container: KeyedDecodingContainer<Qwen3Configuration.CodingKeys> =
            try decoder.container(
                keyedBy: Qwen3Configuration.CodingKeys.self)

        self.hiddenSize = try container.decode(
            Int.self, forKey: Qwen3Configuration.CodingKeys.hiddenSize)
        self.hiddenLayers = try container.decode(
            Int.self, forKey: Qwen3Configuration.CodingKeys.hiddenLayers)
        self.intermediateSize = try container.decode(
            Int.self, forKey: Qwen3Configuration.CodingKeys.intermediateSize)
        self.attentionHeads = try container.decode(
            Int.self, forKey: Qwen3Configuration.CodingKeys.attentionHeads)
        self.rmsNormEps = try container.decode(
            Float.self, forKey: Qwen3Configuration.CodingKeys.rmsNormEps)
        self.vocabularySize = try container.decode(
            Int.self, forKey: Qwen3Configuration.CodingKeys.vocabularySize)
        self.kvHeads = try container.decode(Int.self, forKey: Qwen3Configuration.CodingKeys.kvHeads)
        self.ropeTheta =
            try container.decodeIfPresent(
                Float.self, forKey: Qwen3Configuration.CodingKeys.ropeTheta)
            ?? 1_000_000
        self.headDim = try container.decode(
            Int.self, forKey: Qwen3Configuration.CodingKeys.headDim)
        self.ropeScaling = try container.decodeIfPresent(
            [String: StringOrNumber].self, forKey: Qwen3Configuration.CodingKeys.ropeScaling)
        self.tieWordEmbeddings =
            try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
        self.maxPositionEmbeddings =
            try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768
    }
}
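
To see the new pieces end to end, here is a minimal self-contained sketch. The configuration values are illustrative, not the real Qwen3-Embedding-0.6B config.json; the token IDs are arbitrary placeholders; and real use loads pretrained weights through the registry in Configuration.swift rather than running with random initialization. If `EmbeddingModelOutput` declares `hiddenStates` as optional, unwrap it before pooling.

import Foundation
import MLX

// Illustrative configuration -- not the actual checkpoint's config.json.
let json = """
    {
        "hidden_size": 64,
        "num_hidden_layers": 2,
        "intermediate_size": 128,
        "num_attention_heads": 4,
        "rms_norm_eps": 1e-6,
        "vocab_size": 1000,
        "num_key_value_heads": 2,
        "head_dim": 16
    }
    """
let config = try JSONDecoder().decode(Qwen3Configuration.self, from: Data(json.utf8))
let model = Qwen3Model(config)  // randomly initialized; real loading goes through sanitize(weights:)

// Arbitrary placeholder token IDs: batch of 1, sequence length 3.
let inputIds = MLXArray([1, 2, 3]).reshaped(1, 3)
let output = model(inputIds)

// pooledOutput is nil for this model, so pool the hidden states manually.
// Mean pooling over the sequence axis is shown here; last-token pooling is
// another common choice for Qwen3 embedding checkpoints.
let embedding = output.hiddenStates.mean(axis: 1)  // shape [1, hidden_size]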

Package.resolved

Lines changed: 10 additions & 19 deletions
Some generated files are not rendered by default.

Package.swift

Lines changed: 1 addition & 0 deletions

@@ -118,6 +118,7 @@ let package = Package(
            .product(name: "MLXNN", package: "mlx-swift"),
            .product(name: "Transformers", package: "swift-transformers"),
            .product(name: "MLXLinalg", package: "mlx-swift"),
+           .target(name: "MLXLMCommon"),
        ],
        path: "Libraries/Embedders",
        exclude: [
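
This added target dependency is what lets Libraries/Embedders/Qwen3.swift `import MLXLMCommon`, which provides `KVCache` and the `createAttentionMask` helper used in the model above.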
