Skip to content

Commit 6c0b66f

Browse files
davidkoski1-ashraful-islamJustinMeans
authored
implement LoRA / QLoRA (#46)
* implement LoRA / QLoRA - example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task - see also https://arxiv.org/abs/2106.09685 - based on https://github.com/ml-explore/mlx-examples/tree/main/lora * add some command line flags I found useful during use - --quiet -- don't print decorator text, just the generated text - --prompt @/tmp/file.txt -- load prompt from file * user can specify path to model OR model identifier in huggingface * update mlx-swift reference Co-authored-by: Ashraful Islam <[email protected]> Co-authored-by: JustinMeans <[email protected]>
1 parent 7e85eb8 commit 6c0b66f

File tree

32 files changed

+3484
-65
lines changed

32 files changed

+3484
-65
lines changed

Applications/LLMEval/ContentView.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class LLMEvaluator {
187187
[modelConfiguration] progress in
188188
DispatchQueue.main.sync {
189189
self.modelInfo =
190-
"Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%"
190+
"Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
191191
}
192192
}
193193
self.modelInfo =
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"colors" : [
3+
{
4+
"idiom" : "universal"
5+
}
6+
],
7+
"info" : {
8+
"author" : "xcode",
9+
"version" : 1
10+
}
11+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
{
2+
"images" : [
3+
{
4+
"idiom" : "universal",
5+
"platform" : "ios",
6+
"size" : "1024x1024"
7+
},
8+
{
9+
"idiom" : "mac",
10+
"scale" : "1x",
11+
"size" : "16x16"
12+
},
13+
{
14+
"idiom" : "mac",
15+
"scale" : "2x",
16+
"size" : "16x16"
17+
},
18+
{
19+
"idiom" : "mac",
20+
"scale" : "1x",
21+
"size" : "32x32"
22+
},
23+
{
24+
"idiom" : "mac",
25+
"scale" : "2x",
26+
"size" : "32x32"
27+
},
28+
{
29+
"idiom" : "mac",
30+
"scale" : "1x",
31+
"size" : "128x128"
32+
},
33+
{
34+
"idiom" : "mac",
35+
"scale" : "2x",
36+
"size" : "128x128"
37+
},
38+
{
39+
"idiom" : "mac",
40+
"scale" : "1x",
41+
"size" : "256x256"
42+
},
43+
{
44+
"idiom" : "mac",
45+
"scale" : "2x",
46+
"size" : "256x256"
47+
},
48+
{
49+
"idiom" : "mac",
50+
"scale" : "1x",
51+
"size" : "512x512"
52+
},
53+
{
54+
"idiom" : "mac",
55+
"scale" : "2x",
56+
"size" : "512x512"
57+
}
58+
],
59+
"info" : {
60+
"author" : "xcode",
61+
"version" : 1
62+
}
63+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"info" : {
3+
"author" : "xcode",
4+
"version" : 1
5+
}
6+
}
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
import LLM
4+
import MLX
5+
import MLXOptimizers
6+
import MLXRandom
7+
import SwiftUI
8+
import Tokenizers
9+
10+
struct ContentView: View {
11+
12+
@State var evaluator = LoRAEvaluator()
13+
14+
@State var prompt = """
15+
table: 1-10015132-16
16+
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
17+
Q: What is terrence ross' nationality
18+
A:
19+
"""
20+
21+
var body: some View {
22+
VStack {
23+
HStack {
24+
if let progress = evaluator.progress {
25+
if let current = progress.current, let limit = progress.limit {
26+
ProgressView(progress.title, value: current, total: limit)
27+
} else {
28+
ProgressView(progress.title)
29+
}
30+
}
31+
}
32+
.frame(maxWidth: .infinity, minHeight: 25)
33+
34+
VStack {
35+
ScrollView(.vertical) {
36+
ScrollViewReader { sp in
37+
Group {
38+
Text(evaluator.output)
39+
.textSelection(.enabled)
40+
.frame(maxWidth: .infinity)
41+
}
42+
.onChange(of: evaluator.output) { _, _ in
43+
sp.scrollTo("bottom")
44+
}
45+
.padding()
46+
47+
Spacer()
48+
.frame(width: 1, height: 1)
49+
.id("bottom")
50+
}
51+
}
52+
53+
// controls for each of the different states
54+
VStack {
55+
switch evaluator.state {
56+
case .idle:
57+
Button("Start", action: start)
58+
59+
case .training:
60+
EmptyView()
61+
62+
case .evaluate:
63+
Group {
64+
TextEditor(text: $prompt)
65+
.frame(minHeight: 60)
66+
Button("Evaluate", action: evaluate)
67+
}
68+
.disabled(evaluator.progress != nil)
69+
70+
case .failed(let message):
71+
Text("Failed: \(message)")
72+
.bold()
73+
.foregroundStyle(.red)
74+
}
75+
}
76+
.padding()
77+
.frame(maxWidth: .infinity)
78+
}
79+
}
80+
.padding()
81+
}
82+
83+
func start() {
84+
Task {
85+
await evaluator.start()
86+
}
87+
}
88+
89+
func evaluate() {
90+
Task {
91+
await evaluator.evaluate(prompt: prompt)
92+
}
93+
}
94+
}
95+
96+
/// Progress reporting with a title.
97+
struct Progress: Equatable {
98+
let title: String
99+
let current: Double?
100+
let limit: Double?
101+
}
102+
103+
@Observable
104+
class LoRAEvaluator {
105+
106+
enum State {
107+
case idle
108+
case training
109+
case evaluate
110+
case failed(String)
111+
}
112+
113+
enum ModelState {
114+
case idle
115+
case loaded(LLMModel, Tokenizer)
116+
}
117+
118+
var state = State.idle
119+
var progress: Progress?
120+
121+
var output = ""
122+
123+
private let modelConfiguration = ModelConfiguration.mistral7B4bit
124+
private var model: ModelState = .idle
125+
126+
private let loraLayers = 4
127+
private let learningRate: Float = 1e-5
128+
private let parameters = LoRATrain.Parameters(batchSize: 1, iterations: 200)
129+
130+
private let generateParameters = GenerateParameters(temperature: 0.6, topP: 0.9)
131+
private let evaluateShowEvery = 8
132+
private let maxTokens = 200
133+
134+
private func loadModel() async throws -> (LLMModel, Tokenizer) {
135+
switch self.model {
136+
case .idle:
137+
let name = modelConfiguration.name
138+
await MainActor.run {
139+
progress = .init(title: "Loading \(name)", current: 0, limit: 1)
140+
}
141+
142+
let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
143+
progress in
144+
if progress.fractionCompleted < 1.0 {
145+
DispatchQueue.main.sync {
146+
self.progress = .init(
147+
title: "Download \(name)", current: progress.fractionCompleted,
148+
limit: 1.0)
149+
}
150+
}
151+
}
152+
eval(model)
153+
self.model = .loaded(model, tokenizer)
154+
return (model, tokenizer)
155+
156+
case .loaded(let model, let tokenizer):
157+
return (model, tokenizer)
158+
}
159+
}
160+
161+
private func loadLoRAData(name: String) throws -> [String]? {
162+
if let url = Bundle.main.url(forResource: name, withExtension: "jsonl") {
163+
return try LLM.loadLoRAData(url: url)
164+
}
165+
return nil
166+
}
167+
168+
func start() async {
169+
do {
170+
try await startInner()
171+
} catch {
172+
self.state = .failed("Failed: \(error)")
173+
}
174+
}
175+
176+
private func startInner() async throws {
177+
// setup
178+
GPU.set(cacheLimit: 32 * 1024 * 1024)
179+
await MainActor.run {
180+
output = ""
181+
state = .training
182+
}
183+
184+
// load the model
185+
let (model, tokenizer) = try await loadModel()
186+
187+
// apply LoRA adapters and train
188+
guard let layerProvider = model as? LoRAModel else {
189+
state = .failed("Model must implement the LoRALayerProvider protocol")
190+
return
191+
}
192+
LoRATrain.convert(
193+
model: model, layers: Array(layerProvider.loraLinearLayers().suffix(loraLayers)))
194+
195+
let train = try loadLoRAData(name: "train")
196+
let valid = try loadLoRAData(name: "valid")
197+
guard let train, let valid else {
198+
state = .failed("Failed to load train/validation data")
199+
return
200+
}
201+
202+
let optimizer = Adam(learningRate: learningRate)
203+
try await LoRATrain.train(
204+
model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
205+
parameters: parameters
206+
) { progress in
207+
await MainActor.run {
208+
switch progress {
209+
case .train(let i, _, _, _):
210+
self.progress = .init(
211+
title: "Train", current: Double(i), limit: Double(parameters.iterations))
212+
case .validation:
213+
output += "\n"
214+
default:
215+
break
216+
}
217+
218+
output += progress.description + "\n"
219+
}
220+
221+
return .more
222+
}
223+
224+
// done training, test
225+
await MainActor.run {
226+
self.progress = .init(title: "Testing", current: nil, limit: nil)
227+
}
228+
guard let test = try loadLoRAData(name: "test") else {
229+
state = .failed("Failed to load test data")
230+
return
231+
}
232+
233+
let loss = LoRATrain.evaluate(
234+
model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)
235+
await MainActor.run {
236+
self.progress = nil
237+
self.output += "\n"
238+
self.output += "Test loss \(loss.formatted()), ppl \(exp(loss).formatted())\n"
239+
self.state = .evaluate
240+
}
241+
}
242+
243+
func evaluate(prompt: String) async {
244+
do {
245+
try await evaluateInner(prompt: prompt)
246+
} catch {
247+
self.state = .failed("Failed: \(error)")
248+
}
249+
}
250+
251+
func evaluateInner(prompt: String) async throws {
252+
await MainActor.run {
253+
self.progress = .init(title: "Evaluating", current: nil, limit: nil)
254+
self.output = ""
255+
}
256+
257+
MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
258+
259+
let (model, tokenizer) = try await loadModel()
260+
261+
// prepare the prompt
262+
let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
263+
let promptTokens = tokenizer.encode(text: preparedPrompt)
264+
265+
// evaluate
266+
let result = await LLM.generate(
267+
promptTokens: promptTokens, parameters: generateParameters, model: model,
268+
tokenizer: tokenizer,
269+
didGenerate: { tokens in
270+
if tokens.count % evaluateShowEvery == 0 {
271+
let fullOutput = tokenizer.decode(tokens: tokens)
272+
await MainActor.run {
273+
self.output = fullOutput
274+
}
275+
}
276+
return tokens.count >= maxTokens ? .stop : .more
277+
})
278+
279+
await MainActor.run {
280+
self.output = result.output
281+
self.progress = nil
282+
}
283+
}
284+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3+
<plist version="1.0">
4+
<dict>
5+
<key>com.apple.developer.kernel.increased-memory-limit</key>
6+
<true/>
7+
<key>com.apple.security.app-sandbox</key>
8+
<true/>
9+
<key>com.apple.security.files.user-selected.read-only</key>
10+
<true/>
11+
<key>com.apple.security.network.client</key>
12+
<true/>
13+
</dict>
14+
</plist>

0 commit comments

Comments
 (0)