
Commit 9d74afd

handle partially quantized models (#76)
* handle partially quantized models -- fixes #53, #69, #71, #74
* to make the models easy to test, each model configuration now carries a default prompt of an appropriate form
* while working on the model configuration, also added additional stop tokens (#74)
* fixed the repetitionPenalty code (#71)
1 parent 65f4968 commit 9d74afd

File tree

12 files changed, +139 -67 lines


Applications/LLMEval/ContentView.swift

Lines changed: 4 additions & 2 deletions
@@ -10,7 +10,7 @@ import Tokenizers

 struct ContentView: View {

-    @State var prompt = "compare python and swift"
+    @State var prompt = ""
     @State var llm = LLMEvaluator()
     @Environment(DeviceStat.self) private var deviceStat

@@ -125,6 +125,8 @@ struct ContentView: View {
         }
         .task {
+            self.prompt = llm.modelConfiguration.defaultPrompt
+
             // pre-load the weights on launch to speed up the first generation
             _ = try? await llm.load()
         }

@@ -224,7 +226,7 @@ class LLMEvaluator

         let result = await LLM.generate(
             promptTokens: promptTokens, parameters: generateParameters, model: model,
-            tokenizer: tokenizer
+            tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
         ) { tokens in
             // update the output -- this will make the view show the text as it generates
             if tokens.count % displayEveryNTokens == 0 {

Applications/LoRATrainingExample/ContentView.swift

Lines changed: 1 addition & 0 deletions
@@ -266,6 +266,7 @@ class LoRAEvaluator
         let result = await LLM.generate(
             promptTokens: promptTokens, parameters: generateParameters, model: model,
             tokenizer: tokenizer,
+            extraEOSTokens: modelConfiguration.extraEOSTokens,
             didGenerate: { tokens in
                 if tokens.count % evaluateShowEvery == 0 {
                     let fullOutput = tokenizer.decode(tokens: tokens)

Libraries/LLM/Evaluate.swift

Lines changed: 24 additions & 6 deletions
@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArray
         logits = logits.asType(.float32)
     }

-    let probs = softMax(logits / temp, axis: -1)
+    let probs = softmax(logits / temp, axis: -1)
     let sortedIndices = argSort(probs, axis: -1)

     // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]

@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
 ) -> MLXArray {
     if repetitionContext.shape[0] > 0 {
         let indices = repetitionContext
-        var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+        var selectedLogits = logits[0..., indices]

         selectedLogits = MLX.where(
             selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)

@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         if prompt.shape[0] <= parameters.repetitionContextSize {
             self.repetitionContext = prompt
         } else {
-            self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+            self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
         }
     } else {
         self.repetitionContext = []

@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
         // append the current token to the context and trim the repetitionPenalty context if it grew too long
         if parameters.repetitionContextSize > 1 {
-            repetitionContext = concatenated([repetitionContext, y], axis: 0)
             if repetitionContext.shape[0] > parameters.repetitionContextSize {
-                repetitionContext = repetitionContext[1...]
+                repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
             }
         }

@@ -174,14 +173,31 @@ public enum GenerateDisposition
 /// - parameters: generation parameters
 /// - model: model to evaluate
 /// - tokenizer: tokenizer to convert tokens back into strings and recognize special tokens
+/// - extraEOSTokens: additional tokens to treat as end of string
 /// - didGenerate: visitor for the tokens as they are generated
 public func generate(
     promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
+    extraEOSTokens: Set<String>? = nil,
     didGenerate: ([Int]) async -> GenerateDisposition
 ) async -> GenerateResult {
     var start = Date.timeIntervalSinceReferenceDate
     var promptTime: TimeInterval = 0

+    // build a set of additional stop tokens
+    let additionalEOSTokenIds = Set(
+        (extraEOSTokens ?? [])
+            .map {
+                tokenizer.encode(text: $0)
+            }
+            .filter {
+                // discard anything that is not a single token. sometimes
+                // the tokenizer will insert a <s> token, so accept that too
+                $0.count == 1 || ($0.count == 2 && $0[0] == 1)
+            }
+            .map {
+                $0.last!
+            })
+
     var tokens = [Int]()

     for token in TokenIterator(

@@ -196,7 +212,9 @@ public func generate(
         }

         let t = token.item(Int.self)
-        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
+        if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId
+            || additionalEOSTokenIds.contains(t)
+        {
             break
         }

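Callers opt in by passing the new argument; existing call sites compile unchanged since extraEOSTokens defaults to nil. A minimal usage sketch, not from this commit (assumptions: model and tokenizer were obtained via load(), GenerateParameters exposes a memberwise initializer with defaults, and GenerateDisposition has .more and .stop cases):

let promptTokens = tokenizer.encode(text: "why is the sky blue?")

let result = await generate(
    promptTokens: promptTokens,
    parameters: GenerateParameters(temperature: 0.6),
    model: model, tokenizer: tokenizer,
    extraEOSTokens: ["<|end|>"]  // also stop on Phi-3 style end-of-turn tokens
) { tokens in
    // the visitor runs as tokens stream in; cap the output length here
    tokens.count < 256 ? .more : .stop
}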

Libraries/LLM/LLMModel.swift

Lines changed: 11 additions & 0 deletions
@@ -12,4 +12,15 @@ public protocol LLMModel: Module {
     func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
         MLXArray, [(MLXArray, MLXArray)]
     )
+
+    /// Optionally preprocess the weights and modify / remove values as needed.
+    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
+}
+
+extension LLMModel {
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        weights
+    }
+
 }
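The protocol extension supplies a pass-through default, so only models that need cleanup override it; LlamaModel below is the override this commit actually adds. Purely as an illustration of the "modify" side (hypothetical, not part of this commit), an override could also rename checkpoint keys:

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    var weights = weights
    // hypothetical: some exports prefix every key with "model." -- strip it
    for (key, value) in weights where key.hasPrefix("model.") {
        weights[String(key.dropFirst("model.".count))] = value
        weights[key] = nil
    }
    return weights
}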

Libraries/LLM/Llama.swift

Lines changed: 7 additions & 0 deletions
@@ -194,6 +194,13 @@ public class LlamaModel: Module, LLMModel {
         let (out, cache) = model(inputs, cache: cache)
         return (lmHead(out), cache)
     }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        // Remove unused precomputed rotary freqs
+        weights.filter {
+            !$0.key.contains("self_attn.rotary_emb.inv_freq")
+        }
+    }
 }

 public struct LlamaConfiguration: Codable {

Libraries/LLM/Load.swift

Lines changed: 7 additions & 36 deletions
@@ -54,9 +54,15 @@ public func load(
         }
     }

+    // per-model cleanup
+    weights = model.sanitize(weights: weights)
+
     // quantize if needed
     if let quantization = baseConfig.quantization {
-        quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
+        quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
+            path, module in
+            weights["\(path).scales"] != nil
+        }
     }

     // apply the loaded weights

@@ -76,38 +82,3 @@ public func load(
             hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
     }
 }
-
-// MARK: - Quantization
-
-private func quantizeIfNeeded(
-    model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
-) {
-
-    func linearPredicate(layer: Module) -> Bool {
-        if let layer = layer as? Linear {
-            // avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
-            return layer.weight.dim(0) != 8
-        }
-        return false
-    }
-
-    var predicate = linearPredicate(layer:)
-
-    // for legacy models that don't have lm_head quant due to non-32 dims
-    if weights["lm_head.scales"] == nil {
-        let vocabularySize = model.vocabularySize
-
-        func vocabularySizePredicate(layer: Module) -> Bool {
-            if let layer = layer as? Linear {
-                return layer.weight.dim(0) != 8 && layer.weight.dim(0) != vocabularySize
-            }
-            return false
-        }
-
-        predicate = vocabularySizePredicate(layer:)
-    }
-
-    QuantizedLinear.quantize(
-        model: model, groupSize: quantization.groupSize, bits: quantization.bits,
-        predicate: predicate)
-}
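This is the heart of the fix for partially quantized models: instead of inferring from layer shapes which Linear layers to quantize, load() now quantizes exactly those modules whose key path has companion quantization parameters in the checkpoint. A self-contained sketch of the predicate's behavior, using hypothetical weight keys from a partially quantized checkpoint (lm_head left in floating point, as in the *-no-q-embed models):

// hypothetical keys from a partially quantized checkpoint
let weightKeys: Set<String> = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.q_proj.scales",
    "model.layers.0.self_attn.q_proj.biases",
    "lm_head.weight",  // no "lm_head.scales" entry
]

// the same per-path test that load() now passes to quantize()
func shouldQuantize(path: String) -> Bool {
    weightKeys.contains("\(path).scales")
}

print(shouldQuantize(path: "model.layers.0.self_attn.q_proj"))  // true -> quantize
print(shouldQuantize(path: "lm_head"))                          // false -> keep as-is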

Libraries/LLM/Lora.swift

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ public enum LoRATrain {
 /// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)``
 /// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)``
 /// - fusing with ``fuse(model:layers:deQuantize:)``
-/// - text generation with ``generate(promptTokens:parameters:model:tokenizer:didGenerate:)``
+/// - text generation with ``generate(promptTokens:parameters:model:tokenizer:extraEOSTokens:didGenerate:)``
 /// - note that this is just using normal model text generation
 ///
 /// - Parameters:

Libraries/LLM/Models.swift

Lines changed: 43 additions & 11 deletions
@@ -33,28 +33,42 @@ public struct ModelConfiguration {
     /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
     public let overrideTokenizer: String?

+    /// A reasonable default prompt for the model
+    public let defaultPrompt: String
+
+    /// Additional tokens to use for end of string
+    public let extraEOSTokens: Set<String>
+
     /// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
     /// allows some minor formatting changes, e.g. wrapping the user input in the expected prompt
     /// format
     private let preparePrompt: ((String) -> String)?

     public init(
         id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
+        defaultPrompt: String = "hello",
+        extraEOSTokens: Set<String> = [],
         preparePrompt: ((String) -> String)? = nil
     ) {
         self.id = .id(id)
         self.tokenizerId = tokenizerId
         self.overrideTokenizer = overrideTokenizer
+        self.defaultPrompt = defaultPrompt
+        self.extraEOSTokens = extraEOSTokens
         self.preparePrompt = preparePrompt
     }

     public init(
         directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
+        defaultPrompt: String = "hello",
+        extraEOSTokens: Set<String> = [],
         preparePrompt: ((String) -> String)? = nil
     ) {
         self.id = .directory(directory)
         self.tokenizerId = tokenizerId
         self.overrideTokenizer = overrideTokenizer
+        self.defaultPrompt = defaultPrompt
+        self.extraEOSTokens = extraEOSTokens
         self.preparePrompt = preparePrompt
     }

@@ -98,11 +112,16 @@ public struct ModelConfiguration {
 extension ModelConfiguration {

     public static let mistral7B4bit = ModelConfiguration(
-        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
+        id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
+
+        // https://www.promptingguide.ai/models/mistral-7b
+        defaultPrompt: "describe the swift language"
+    )

     public static let codeLlama13b4bit = ModelConfiguration(
         id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+        defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
     ) { prompt in
         // given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
         // the python code produces this (via its custom tokenizer):

@@ -111,40 +130,53 @@ extension ModelConfiguration {
         "<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
     }

-    public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
-        prompt in
-        "Instruct: \(prompt)\nOutput: "
-    }
+    public static let phi4bit = ModelConfiguration(
+        id: "mlx-community/phi-2-hf-4bit-mlx",
+
+        // https://www.promptingguide.ai/models/phi-2
+        defaultPrompt: "Why is the sky blue?"
+    )

     public static let phi34bit = ModelConfiguration(
-        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
+        id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
+        defaultPrompt: "what is the gravity on mars and the moon?",
+        extraEOSTokens: ["<|end|>"]
     ) {
         prompt in
         "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
     }

     public static let gemma2bQuantized = ModelConfiguration(
         id: "mlx-community/quantized-gemma-2b-it",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+
+        // https://www.promptingguide.ai/models/gemma
+        defaultPrompt: "what is the difference between lettuce and cabbage?"
+
     ) { prompt in
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }

     public static let qwen205b4bit = ModelConfiguration(
         id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
-        overrideTokenizer: "PreTrainedTokenizer"
+        overrideTokenizer: "PreTrainedTokenizer",
+        defaultPrompt: "why is the sky blue?"
     ) { prompt in
         "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
     }

     public static let openelm270m4bit = ModelConfiguration(
-        id: "mlx-community/OpenELM-270M-Instruct"
+        id: "mlx-community/OpenELM-270M-Instruct",
+
+        // https://huggingface.co/apple/OpenELM
+        defaultPrompt: "Once upon a time there was"
     ) { prompt in
         "\(prompt)"
     }

     public static let llama38B4bit = ModelConfiguration(
-        id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
+        id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+        defaultPrompt: "what is the difference between a fruit and a vegetable?"
     ) {
         prompt in
         "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"

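Adding the new fields to a configuration is just two extra arguments on the existing initializers; a sketch mirroring the Phi-3 registration above (the id, prompt, stop token, and prompt format are taken from this diff; the constant name is arbitrary):

let phi3 = ModelConfiguration(
    id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
    defaultPrompt: "what is the gravity on mars and the moon?",
    extraEOSTokens: ["<|end|>"]  // emitted by Phi-3 at the end of a turn
) { prompt in
    "<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}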