Skip to content

Commit 3021ed0

Browse files
committed
feat: add lora support
1 parent ce816dd commit 3021ed0

File tree

2 files changed

+125
-17
lines changed

2 files changed

+125
-17
lines changed

Sources/FLUX.swift

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,47 @@ import Tokenizers
77

88
open class FLUX {
99

10+
/// Resolves `loraPath` to a directory and loads the LoRA weights stored in it.
///
/// When something exists on disk at `loraPath`, that location is used directly.
/// Otherwise the string is treated as a Hub repository id and the repository's
/// `*.safetensors` files are snapshotted through `hub` before loading.
///
/// - Parameters:
///   - hub: Hub client used to snapshot and locate the repository.
///   - loraPath: Local path, or a Hub repository id when no such path exists.
///   - dType: Target dtype forwarded to `loadLoraWeights(directory:dType:)`.
/// - Returns: LoRA weights keyed by their remapped names.
/// - Throws: Any error raised by the snapshot or by loading the weight files.
internal func loadLoraWeights(hub: HubApi, loraPath: String, dType: DType) async throws
    -> [String: MLXArray]
{
    // Local path: no download needed.
    if FileManager.default.fileExists(atPath: loraPath) {
        let localDirectory = URL(fileURLWithPath: loraPath)
        return try Self.loadLoraWeights(directory: localDirectory, dType: dType)
    }

    // Not on disk: interpret the string as a Hub repository id.
    let repo = Hub.Repo(id: loraPath)
    try await hub.snapshot(from: repo, matching: ["*.safetensors"])
    let snapshotDirectory = hub.localRepoLocation(repo)
    return try Self.loadLoraWeights(directory: snapshotDirectory, dType: dType)
}
24+
25+
/// Loads every `.safetensors` file found under `directory` into one dictionary.
///
/// Each key is passed through `remapWeightKey(_:)`. Arrays whose dtype is not
/// `.bfloat16` are converted to `dType`; `.bfloat16` arrays are kept unchanged.
/// If several files produce the same remapped key, the file enumerated last wins.
///
/// - Parameters:
///   - directory: Location to enumerate for `.safetensors` files.
///   - dType: Dtype applied to every non-`.bfloat16` array.
/// - Returns: The merged weights keyed by remapped name.
/// - Throws: `CocoaError(.fileReadUnknown)` when `directory` cannot be enumerated,
///   or any error thrown by `loadArrays(url:)`.
internal static func loadLoraWeights(directory: URL, dType: DType) throws -> [String: MLXArray] {
    // `enumerator(at:includingPropertiesForKeys:)` returns an optional; report a
    // failure as a thrown error instead of crashing on a force-unwrap.
    guard
        let enumerator = FileManager.default.enumerator(
            at: directory, includingPropertiesForKeys: nil)
    else {
        throw CocoaError(.fileReadUnknown, userInfo: [NSFilePathErrorKey: directory.path])
    }

    var loraWeights = [String: MLXArray]()
    for case let url as URL in enumerator where url.pathExtension == "safetensors" {
        for (key, value) in try loadArrays(url: url) {
            let newKey = remapWeightKey(key)
            loraWeights[newKey] = value.dtype != .bfloat16 ? value.asType(dType) : value
        }
    }
    return loraWeights
}
44+
1045
internal static func remapWeightKey(_ key: String) -> String {
11-
if (key.contains(".ff.") || key.contains(".ff_context.")) {
46+
if key.contains(".ff.") || key.contains(".ff_context.") {
1247
let components = key.components(separatedBy: ".")
1348
if components.count >= 5 {
1449
let blockIndex = components[1]
15-
let ffType = components[2] // "ff" or "ff_context"
50+
let ffType = components[2] // "ff" or "ff_context"
1651
let netIndex = components[4]
1752

1853
if netIndex == "0" {
@@ -89,7 +124,7 @@ open class FLUX {
89124
{
90125
t5Weights["relative_attention_bias.weight"] = relativeAttentionBias
91126
}
92-
127+
93128
t5Encoder.update(parameters: ModuleParameters.unflattened(t5Weights))
94129
return t5Encoder
95130
}
@@ -112,8 +147,8 @@ open class FLUX {
112147
public class Flux1Schnell: FLUX, TextToImageGenerator {
113148
let clipTokenizer: CLIPTokenizer
114149
let t5Tokenizer: any Tokenizer
115-
let vae: VAE
116150
let transformer: MultiModalDiffusionTransformer
151+
let vae: VAE
117152
let t5Encoder: T5Encoder
118153
let clipEncoder: CLIPEncoder
119154

@@ -154,7 +189,12 @@ public class Flux1Schnell: FLUX, TextToImageGenerator {
154189

155190
public func generateLatents(parameters: EvaluateParameters) -> DenoiseIterator {
156191
let latentsShape = [1, (parameters.height / 16) * (parameters.width / 16), 64]
157-
let latents = MLXRandom.normal(latentsShape, key: MLXRandom.key(parameters.seed))
192+
let latents: MLXArray
193+
if let seed = parameters.seed {
194+
latents = MLXRandom.normal(latentsShape, key: MLXRandom.key(seed))
195+
} else {
196+
latents = MLXRandom.normal(latentsShape)
197+
}
158198
let (promptEmbeddings, pooledPromptEmbeddings) = conditionText(prompt: parameters.prompt)
159199

160200
return DenoiseIterator(
@@ -203,8 +243,8 @@ public class Flux1Schnell: FLUX, TextToImageGenerator {
203243
public class Flux1Dev: FLUX, TextToImageGenerator {
204244
let clipTokenizer: CLIPTokenizer
205245
let t5Tokenizer: any Tokenizer
206-
let vae: VAE
207246
let transformer: MultiModalDiffusionTransformer
247+
let vae: VAE
208248
let t5Encoder: T5Encoder
209249
let clipEncoder: CLIPEncoder
210250

@@ -222,7 +262,8 @@ public class Flux1Dev: FLUX, TextToImageGenerator {
222262
private static func loadTransformer(directory: URL, dType: DType) throws
223263
-> MultiModalDiffusionTransformer
224264
{
225-
let transformer = MultiModalDiffusionTransformer(MultiModalDiffusionConfiguration(guidanceEmbeds: true))
265+
let transformer = MultiModalDiffusionTransformer(
266+
MultiModalDiffusionConfiguration(guidanceEmbeds: true))
226267
var transformerWeights = [String: MLXArray]()
227268
let enumerator = FileManager.default.enumerator(
228269
at: directory.appending(path: "transformer"), includingPropertiesForKeys: nil)!
@@ -245,7 +286,12 @@ public class Flux1Dev: FLUX, TextToImageGenerator {
245286

246287
public func generateLatents(parameters: EvaluateParameters) -> DenoiseIterator {
247288
let latentsShape = [1, (parameters.height / 16) * (parameters.width / 16), 64]
248-
let latents = MLXRandom.normal(latentsShape, key: MLXRandom.key(parameters.seed))
289+
let latents: MLXArray
290+
if let seed = parameters.seed {
291+
latents = MLXRandom.normal(latentsShape, key: MLXRandom.key(seed))
292+
} else {
293+
latents = MLXRandom.normal(latentsShape)
294+
}
249295
let (promptEmbeddings, pooledPromptEmbeddings) = conditionText(prompt: parameters.prompt)
250296

251297
return DenoiseIterator(
@@ -348,4 +394,4 @@ public struct DenoiseIterator: Sequence, IteratorProtocol {
348394
i += 1
349395
return latents
350396
}
351-
}
397+
}

Sources/FluxConfiguration.swift

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@ public struct LoadConfiguration: Sendable {
1111
/// quantize weights
1212
public var quantize = false
1313

14+
public let loraPath: String?
1415
public var dType: DType {
1516
float16 ? .float16 : .float32
1617
}
1718

18-
public init(float16: Bool = true, quantize: Bool = false) {
19+
public init(
20+
float16: Bool = true, quantize: Bool = false, loraPath: String? = nil
21+
) {
1922
self.float16 = float16
2023
self.quantize = quantize
24+
self.loraPath = loraPath
2125
}
2226
}
2327

@@ -26,14 +30,20 @@ public struct EvaluateParameters {
2630
public var height: Int
2731
public var numInferenceSteps: Int
2832
public var guidance: Float
29-
public var seed: UInt64
33+
public var seed: UInt64?
3034
public var prompt: String
3135
public var numTrainSteps: Int
3236
public let sigmas: MLXArray
3337

3438
public init(
35-
numInferenceSteps: Int = 4, width: Int = 1024, height: Int = 1024, guidance: Float = 4.0,
36-
seed: UInt64 = 0, prompt: String = "", numTrainSteps: Int = 1000, shiftSigmas: Bool = false
39+
width: Int = 512,
40+
height: Int = 512,
41+
numInferenceSteps: Int = 4,
42+
guidance: Float = 4.0,
43+
seed: UInt64? = nil,
44+
prompt: String = "",
45+
numTrainSteps: Int = 1000,
46+
shiftSigmas: Bool = false
3747
) {
3848
if width % 16 != 0 || height % 16 != 0 {
3949
print("Warning: Width and height should be multiples of 16. Rounding down.")
@@ -77,12 +87,36 @@ enum FileKey {
7787
case tokenizer2
7888
}
7989

90+
// TODO: add support for mlx flux fine-tuning
91+
/// Merges LoRA updates into a copy of the transformer's flattened weights.
///
/// For each `Linear` module reported by `transform.namedModules()`, this looks up
/// `"transformer.<name>.lora_A.weight"` and `"transformer.<name>.lora_B.weight"`
/// in `loraWeight`. When both exist and `"<name>.weight"` is present in
/// `transformerWeight`, that entry is replaced by
/// `weight + loraScale * matmul(lora_B, lora_A)`. All other entries are returned
/// unchanged.
///
/// - Parameters:
///   - transform: Module whose `Linear` submodules determine which keys are fused.
///   - transformerWeight: Flattened base weights keyed by parameter path.
///   - loraWeight: LoRA weights keyed with the `transformer.` prefix described above.
///   - loraScale: Multiplier applied to the low-rank update. Defaults to `1.0`,
///     the value that was previously hard-coded.
/// - Returns: A copy of `transformerWeight` with the matching entries fused.
func fuseLoraWeights(
    transform: Module, transformerWeight: [String: MLXArray], loraWeight: [String: MLXArray],
    loraScale: Float = 1.0
) -> [String: MLXArray] {
    var fusedWeights = transformerWeight

    for (key, module) in transform.namedModules() where module is Linear {
        let weightKey = key + ".weight"
        guard
            let loraA = loraWeight["transformer." + key + ".lora_A.weight"],
            let loraB = loraWeight["transformer." + key + ".lora_B.weight"],
            let baseWeight = fusedWeights[weightKey]
        else { continue }

        let delta = loraScale * MLX.matmul(loraB, loraA)
        // The LoRA loader leaves `.bfloat16` arrays unconverted, so the update can
        // have a different dtype than the base weight. Cast it back so that fusing
        // never changes the dtype of a transformer parameter.
        fusedWeights[weightKey] = baseWeight + delta.asType(baseWeight.dtype)
    }
    return fusedWeights
}
113+
80114
public struct FluxConfiguration: Sendable {
81-
public let id: String
115+
public var id: String
82116
let files: [FileKey: String]
83117
public let defaultParameters: @Sendable () -> EvaluateParameters
84118
let factory:
85-
@Sendable (HubApi, FluxConfiguration, LoadConfiguration) throws ->
119+
@Sendable (HubApi, FluxConfiguration, LoadConfiguration) async throws ->
86120
FLUX
87121

88122
public func download(
@@ -94,9 +128,9 @@ public struct FluxConfiguration: Sendable {
94128
}
95129

96130
public func textToImageGenerator(hub: HubApi = HubApi(), configuration: LoadConfiguration)
97-
throws -> TextToImageGenerator?
131+
async throws -> TextToImageGenerator?
98132
{
99-
try factory(hub, self, configuration) as? TextToImageGenerator
133+
try await factory(hub, self, configuration) as? TextToImageGenerator
100134
}
101135

102136
public static let flux1Schnell = FluxConfiguration(
@@ -113,6 +147,20 @@ public struct FluxConfiguration: Sendable {
113147
factory: { hub, fluxConfiguration, loadConfiguration in
114148
let flux = try Flux1Schnell(
115149
hub: hub, configuration: fluxConfiguration, dType: loadConfiguration.dType)
150+
151+
if let loraPath = loadConfiguration.loraPath {
152+
let loraWeight = try await flux.loadLoraWeights(
153+
hub: hub, loraPath: loraPath, dType: loadConfiguration.dType)
154+
155+
let weights = fuseLoraWeights(
156+
transform: flux.transformer,
157+
transformerWeight: Dictionary(
158+
uniqueKeysWithValues: flux.transformer.parameters().flattened()), loraWeight: loraWeight
159+
)
160+
161+
flux.transformer.update(parameters: ModuleParameters.unflattened(weights))
162+
}
163+
116164
if loadConfiguration.quantize {
117165
quantize(model: flux.clipEncoder, filter: { k, m in m is Linear })
118166
quantize(model: flux.t5Encoder, filter: { k, m in m is Linear })
@@ -141,6 +189,20 @@ public struct FluxConfiguration: Sendable {
141189
factory: { hub, fluxConfiguration, loadConfiguration in
142190
let flux = try Flux1Dev(
143191
hub: hub, configuration: fluxConfiguration, dType: loadConfiguration.dType)
192+
193+
if let loraPath = loadConfiguration.loraPath {
194+
let loraWeight = try await flux.loadLoraWeights(
195+
hub: hub, loraPath: loraPath, dType: loadConfiguration.dType)
196+
197+
let weights = fuseLoraWeights(
198+
transform: flux.transformer,
199+
transformerWeight: Dictionary(
200+
uniqueKeysWithValues: flux.transformer.parameters().flattened()), loraWeight: loraWeight
201+
)
202+
203+
flux.transformer.update(parameters: ModuleParameters.unflattened(weights))
204+
}
205+
144206
if loadConfiguration.quantize {
145207
quantize(model: flux.clipEncoder, filter: { k, m in m is Linear })
146208
quantize(model: flux.t5Encoder, filter: { k, m in m is Linear })

0 commit comments

Comments
 (0)