
Commit b951b78

awni and davidkoski authored
phi3 (#54)

* phi3 Co-authored-by: David Koski <[email protected]>

1 parent 6c0b66f · commit b951b78

File tree: 7 files changed, +284 -8 lines changed

Applications/LLMEval/ContentView.swift
Lines changed: 1 addition & 1 deletion

@@ -157,7 +157,7 @@ class LLMEvaluator {
 
     /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
     /// more devices
-    let modelConfiguration = ModelConfiguration.phi4bit
+    let modelConfiguration = ModelConfiguration.phi34bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
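Since `modelConfiguration` is the single switch that selects the model, trying Phi-3 in LLMEval is this one-line change. A hedged sketch of the surrounding wiring (the `load(configuration:)` loader entry point is an assumption, not shown in this diff):

    // Hypothetical wiring, for illustration only; the loader API is assumed.
    let modelConfiguration = ModelConfiguration.phi34bit
    let (model, tokenizer) = try await load(configuration: modelConfiguration)
    let generateParameters = GenerateParameters(temperature: 0.6)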

Libraries/LLM/Configuration.swift
Lines changed: 5 additions & 0 deletions

@@ -30,6 +30,7 @@ public enum ModelType: String, Codable {
     case mistral
     case llama
     case phi
+    case phi3
     case gemma
     case qwen2
     case starcoder2
@@ -45,6 +46,10 @@ public enum ModelType: String, Codable {
             let configuration = try JSONDecoder().decode(
                 PhiConfiguration.self, from: Data(contentsOf: configuration))
             return PhiModel(configuration)
+        case .phi3:
+            let configuration = try JSONDecoder().decode(
+                Phi3Configuration.self, from: Data(contentsOf: configuration))
+            return Phi3Model(configuration)
         case .gemma:
             let configuration = try JSONDecoder().decode(
                 GemmaConfiguration.self, from: Data(contentsOf: configuration))
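The raw string value of each case must match the `model_type` field in a checkpoint's `config.json`; that string is what routes decoding to `Phi3Configuration` and `Phi3Model`. A minimal sketch of the mapping (`Probe` is a hypothetical helper, not library code):

    import Foundation

    // The model_type string in config.json decodes directly to the enum case.
    struct Probe: Codable {
        let model_type: ModelType
    }

    let data = #"{"model_type": "phi3"}"#.data(using: .utf8)!
    let probe = try JSONDecoder().decode(Probe.self, from: data)
    // probe.model_type == .phi3, which dispatches to Phi3Model above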

Libraries/LLM/Evaluate.swift
Lines changed: 5 additions & 5 deletions

@@ -60,16 +60,16 @@ public struct GenerateParameters {
     public var temperature: Float = 0.6
 
     /// top p sampling
-    public var topP: Float = 0.9
+    public var topP: Float = 1.0
 
     /// penalty factor for repeating tokens
-    public var repetitionPenalty: Float = 1.0
+    public var repetitionPenalty: Float?
 
     /// number of tokens to consider for repetition penalty
    public var repetitionContextSize: Int = 20
 
    public init(
-        temperature: Float = 0.6, topP: Float = 0.9, repetitionPenalty: Float = 1.0,
+        temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
         repetitionContextSize: Int = 20
     ) {
         self.temperature = temperature
@@ -111,11 +111,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         var logits: MLXArray
         (logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
         logits = logits[0..., -1, 0...]
-        if parameters.repetitionPenalty > 1.0 {
+        if let repetitionPenalty = parameters.repetitionPenalty {
             // apply repetition penalty
             logits = applyRepetitionPenalty(
                 logits: logits, repetitionContext: repetitionContext,
-                penalty: parameters.repetitionPenalty)
+                penalty: repetitionPenalty)
         }
         y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the context and check repetitionPenalty context see if need to remove the first token
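`applyRepetitionPenalty` itself is unchanged here; it implements a CTRL-style penalty that rescales the logits of the last `repetitionContextSize` tokens so they are less likely to be sampled again. A scalar sketch of that idea, written over plain arrays for clarity (an assumption about what the MLX implementation computes, not the library code):

    // CTRL-style penalty sketch: dividing a positive logit by penalty > 1,
    // or multiplying a negative one by it, always lowers that token's
    // sampling probability.
    func penalize(logits: inout [Float], recentTokens: [Int], penalty: Float) {
        for token in Set(recentTokens) {
            let l = logits[token]
            logits[token] = l > 0 ? l / penalty : l * penalty
        }
    }

Making the parameter `Float?` is a cleaner switch than the old `> 1.0` guard: `nil` means the penalty branch is skipped entirely, while any explicit value, including values below 1.0, is now honored.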

Libraries/LLM/Models.swift
Lines changed: 8 additions & 0 deletions

@@ -116,6 +116,13 @@ extension ModelConfiguration {
         "Instruct: \(prompt)\nOutput: "
     }
 
+    public static let phi34bit = ModelConfiguration(
+        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
+    ) {
+        prompt in
+        "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
+    }
+
     public static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
         overrideTokenizer: "PreTrainedTokenizer"
@@ -146,6 +153,7 @@ extension ModelConfiguration {
         mistral7B4bit,
         codeLlama13b4bit,
         phi4bit,
+        phi34bit,
         gemma2bQuantized,
         qwen205b4bit,
     ])
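The trailing closure is the preset's prompt formatter; a worked example of the string it produces (the prompt text is illustrative):

    let prompt = "why is the sky blue?"
    let formatted = "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
    // formatted == "<s><|user|>\nwhy is the sky blue?<|end|>\n<|assistant|>\n"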

Libraries/LLM/Phi3.swift
Lines changed: 257 additions & 0 deletions

@@ -0,0 +1,257 @@
+// Copyright © 2024 Apple Inc.
+
+import Foundation
+import MLX
+import MLXFast
+import MLXNN
+
+private class Attention: Module {
+
+    let args: Phi3Configuration
+    let scale: Float
+
+    @ModuleInfo(key: "qkv_proj") var wqkv: Linear
+    @ModuleInfo(key: "o_proj") var wo: Linear
+
+    let rope: RoPE
+
+    public init(_ args: Phi3Configuration) {
+        self.args = args
+
+        let dim = args.hiddenSize
+        let heads = args.attentionHeads
+        let kvHeads = args.kvHeads
+
+        let headDim = args.hiddenSize / heads
+        self.scale = pow(Float(headDim), -0.5)
+
+        self._wqkv.wrappedValue = Linear(dim, (heads + 2 * kvHeads) * headDim, bias: false)
+        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+
+        let ropeScale: Float
+        if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
+            let factor = ropeScaling["factor"]
+        {
+            switch factor {
+            case .string:
+                fatalError("ropeScaling.factor must be a float")
+            case .float(let v):
+                ropeScale = 1 / v
+            }
+        } else {
+            ropeScale = 1
+        }
+
+        self.rope = RoPE(
+            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
+            scale: ropeScale)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
+        let (B, L) = (x.dim(0), x.dim(1))
+
+        let qkv = split(wqkv(x), parts: 3, axis: -1)
+        var queries = qkv[0]
+        var keys = qkv[1]
+        var values = qkv[2]
+
+        // prepare the queries, keys and values for the attention computation
+        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
+        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+
+        if let (keyCache, valueCache) = cache {
+            queries = rope(queries, offset: keyCache.dim(2))
+            keys = rope(keys, offset: keyCache.dim(2))
+            keys = concatenated([keyCache, keys], axis: 2)
+            values = concatenated([valueCache, values], axis: 2)
+        } else {
+            queries = rope(queries)
+            keys = rope(keys)
+        }
+
+        let output = MLXFast.scaledDotProductAttention(
+            queries: queries, keys: keys, values: values, scale: scale, mask: mask
+        )
+        .transposed(0, 2, 1, 3)
+        .reshaped(B, L, -1)
+
+        return (wo(output), (keys, values))
+    }
+}
+
+private class MLP: Module, UnaryLayer {
+
+    @ModuleInfo(key: "gate_up_proj") var gate_up: Linear
+    @ModuleInfo(key: "down_proj") var down: Linear
+
+    public init(dimensions: Int, hiddenDimensions: Int) {
+        self._gate_up.wrappedValue = Linear(dimensions, 2 * hiddenDimensions, bias: false)
+        self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
+    }
+
+    public func callAsFunction(_ x: MLXArray) -> MLXArray {
+        let gu = split(gate_up(x), parts: 2, axis: -1)
+        return down(silu(gu[0]) * gu[1])
+    }
+}
+
+private class TransformerBlock: Module {
+
+    @ModuleInfo(key: "self_attn") var attention: Attention
+    let mlp: MLP
+
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+
+    public init(_ args: Phi3Configuration) {
+        self._attention.wrappedValue = Attention(args)
+        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        self._inputLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
+        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
+        let h = x + r
+        r = mlp(postAttentionLayerNorm(h))
+        let out = h + r
+        return (out, cache)
+    }
+}
+
+public class Phi3ModelInner: Module {
+
+    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+
+    fileprivate let layers: [TransformerBlock]
+    let norm: RMSNorm
+
+    public init(_ args: Phi3Configuration) {
+        precondition(args.vocabularySize > 0)
+
+        self._embedTokens.wrappedValue = Embedding(
+            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
+
+        self.layers = (0 ..< args.hiddenLayers)
+            .map { _ in
+                TransformerBlock(args)
+            }
+        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        var h = embedTokens(inputs)
+
+        var mask: MLXArray? = nil
+        if h.dim(1) > 1 {
+            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
+            mask = mask?.asType(h.dtype)
+        }
+
+        var newCache = [(MLXArray, MLXArray)]()
+
+        for (i, layer) in layers.enumerated() {
+            var cacheUpdate: (MLXArray, MLXArray)
+            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
+            newCache.append(cacheUpdate)
+        }
+
+        return (norm(h), newCache)
+    }
+}
+
+public class Phi3Model: Module, LLMModel {
+
+    public let vocabularySize: Int
+    let model: Phi3ModelInner
+
+    @ModuleInfo(key: "lm_head") var lmHead: Linear
+
+    public init(_ args: Phi3Configuration) {
+        self.vocabularySize = args.vocabularySize
+        self.model = Phi3ModelInner(args)
+        self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        let (out, cache) = model(inputs, cache: cache)
+        return (lmHead(out), cache)
+    }
+}
+
+public struct Phi3Configuration: Codable {
+
+    var hiddenSize: Int
+    var hiddenLayers: Int
+    var intermediateSize: Int
+    var attentionHeads: Int
+    var rmsNormEps: Float
+    var vocabularySize: Int
+    var kvHeads: Int
+    var ropeTheta: Float = 10_000
+    var ropeTraditional: Bool = false
+    var ropeScaling: [String: StringOrNumber]? = nil
+
+    enum CodingKeys: String, CodingKey {
+        case hiddenSize = "hidden_size"
+        case hiddenLayers = "num_hidden_layers"
+        case intermediateSize = "intermediate_size"
+        case attentionHeads = "num_attention_heads"
+        case rmsNormEps = "rms_norm_eps"
+        case vocabularySize = "vocab_size"
+        case kvHeads = "num_key_value_heads"
+        case ropeTheta = "rope_theta"
+        case ropeTraditional = "rope_traditional"
+        case ropeScaling = "rope_scaling"
+    }
+
+    public init(from decoder: Decoder) throws {
+        // custom implementation to handle optional keys with required values
+        let container: KeyedDecodingContainer<Phi3Configuration.CodingKeys> =
+            try decoder.container(
+                keyedBy: Phi3Configuration.CodingKeys.self)
+
+        self.hiddenSize = try container.decode(
+            Int.self, forKey: Phi3Configuration.CodingKeys.hiddenSize)
+        self.hiddenLayers = try container.decode(
+            Int.self, forKey: Phi3Configuration.CodingKeys.hiddenLayers)
+        self.intermediateSize = try container.decode(
+            Int.self, forKey: Phi3Configuration.CodingKeys.intermediateSize)
+        self.attentionHeads = try container.decode(
+            Int.self, forKey: Phi3Configuration.CodingKeys.attentionHeads)
+        self.rmsNormEps = try container.decode(
+            Float.self, forKey: Phi3Configuration.CodingKeys.rmsNormEps)
+        self.vocabularySize = try container.decode(
+            Int.self, forKey: Phi3Configuration.CodingKeys.vocabularySize)
+        self.kvHeads = try container.decode(Int.self, forKey: Phi3Configuration.CodingKeys.kvHeads)
+        self.ropeTheta =
+            try container.decodeIfPresent(
+                Float.self, forKey: Phi3Configuration.CodingKeys.ropeTheta)
+            ?? 10_000
+        self.ropeTraditional =
+            try container.decodeIfPresent(
+                Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false
+        self.ropeScaling = try container.decodeIfPresent(
+            [String: StringOrNumber].self, forKey: Phi3Configuration.CodingKeys.ropeScaling)
+
+    }
+}
+
+// MARK: - LoRA
+
+extension Phi3Model: LoRAModel {
+    public func loraLinearLayers() -> LoRALinearLayers {
+        model.layers.map { ($0.attention, ["qkv_proj"]) }
+    }
+}
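Two fused projections in the new file are worth calling out: `qkv_proj` stacks the query/key/value matrices and is split into three equal parts, which is only valid when `attentionHeads == kvHeads` (true for the Phi-3-mini checkpoint this commit targets), and `gate_up_proj` stacks the SwiGLU gate and up projections. A hypothetical unfused form of `MLP`, to make the gating structure explicit (a sketch, not part of the commit):

    import MLX
    import MLXNN

    // gate_up_proj is these two matrices stacked along the output axis, so
    // splitting its output in two recovers the standard SwiGLU feed-forward:
    //     down(silu(gate(x)) * up(x))
    private class UnfusedMLP: Module, UnaryLayer {
        let gate: Linear
        let up: Linear
        let down: Linear

        init(dimensions: Int, hiddenDimensions: Int) {
            self.gate = Linear(dimensions, hiddenDimensions, bias: false)
            self.up = Linear(dimensions, hiddenDimensions, bias: false)
            self.down = Linear(hiddenDimensions, dimensions, bias: false)
        }

        func callAsFunction(_ x: MLXArray) -> MLXArray {
            down(silu(gate(x)) * up(x))
        }
    }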

Tools/llm-tool/LLMTool.swift
Lines changed: 2 additions & 2 deletions

@@ -53,10 +53,10 @@ struct GenerateArguments: ParsableArguments {
     var temperature: Float = 0.6
 
     @Option(name: .long, help: "The top p sampling")
-    var topP: Float = 0.9
+    var topP: Float = 1.0
 
     @Option(name: .long, help: "The penalty factor for repeating tokens")
-    var repetitionPenalty: Float = 1.0
+    var repetitionPenalty: Float?
 
     @Option(name: .long, help: "The number of tokens to consider for repetition penalty")
     var repetitionContextSize: Int = 20
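swift-argument-parser derives kebab-case flags from these property names, so the penalty is now strictly opt-in on the command line: omit `--repetition-penalty` and `repetitionPenalty` stays `nil`, skipping the penalty branch in Evaluate.swift. An illustrative invocation (the `mlx-run` wrapper is an assumption about the repo's tooling):

    ./mlx-run llm-tool --prompt "why is the sky blue?" --repetition-penalty 1.2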
