@@ -11,13 +11,17 @@ public struct LoadConfiguration: Sendable {
1111 /// quantize weights
1212 public var quantize = false
1313
14+ public let loraPath : String ?
1415 public var dType : DType {
1516 float16 ? . float16 : . float32
1617 }
1718
18- public init ( float16: Bool = true , quantize: Bool = false ) {
19+ public init (
20+ float16: Bool = true , quantize: Bool = false , loraPath: String ? = nil
21+ ) {
1922 self . float16 = float16
2023 self . quantize = quantize
24+ self . loraPath = loraPath
2125 }
2226}
2327
@@ -26,14 +30,20 @@ public struct EvaluateParameters {
2630 public var height : Int
2731 public var numInferenceSteps : Int
2832 public var guidance : Float
29- public var seed : UInt64
33+ public var seed : UInt64 ?
3034 public var prompt : String
3135 public var numTrainSteps : Int
3236 public let sigmas : MLXArray
3337
3438 public init (
35- numInferenceSteps: Int = 4 , width: Int = 1024 , height: Int = 1024 , guidance: Float = 4.0 ,
36- seed: UInt64 = 0 , prompt: String = " " , numTrainSteps: Int = 1000 , shiftSigmas: Bool = false
39+ width: Int = 512 ,
40+ height: Int = 512 ,
41+ numInferenceSteps: Int = 4 ,
42+ guidance: Float = 4.0 ,
43+ seed: UInt64 ? = nil ,
44+ prompt: String = " " ,
45+ numTrainSteps: Int = 1000 ,
46+ shiftSigmas: Bool = false
3747 ) {
3848 if width % 16 != 0 || height % 16 != 0 {
3949 print ( " Warning: Width and height should be multiples of 16. Rounding down. " )
@@ -77,12 +87,36 @@ enum FileKey {
7787 case tokenizer2
7888}
7989
// TODO: add support for mlx flux fine-tuning
/// Fuses LoRA adapter weights into a transformer's base weights.
///
/// For every `Linear` submodule of `transform`, looks up the matching
/// `lora_A` / `lora_B` matrices in `loraWeight` (keyed
/// `"transformer.<module>.lora_A.weight"` / `"...lora_B.weight"`) and, when
/// both are present together with the base weight, adds
/// `loraScale * matmul(loraB, loraA)` to that base weight.
///
/// - Parameters:
///   - transform: Module whose `Linear` submodules carry the LoRA adapters.
///   - transformerWeight: Flattened base weights, keyed by `"<module>.weight"`.
///   - loraWeight: LoRA adapter weights, keyed with the `"transformer."` prefix.
///   - loraScale: Scale applied to the LoRA delta before fusing. Defaults to
///     `1.0`, matching the previously hard-coded value (backward compatible).
/// - Returns: A copy of `transformerWeight` with the LoRA deltas fused in;
///   entries without a matching adapter pair are returned unchanged.
func fuseLoraWeights(
    transform: Module,
    transformerWeight: [String: MLXArray],
    loraWeight: [String: MLXArray],
    loraScale: Float = 1.0
) -> [String: MLXArray] {
    var fusedWeights = transformerWeight

    for (key, module) in transform.namedModules() {
        // Only Linear layers carry LoRA adapters.
        guard module is Linear else { continue }

        let loraAKey = "transformer." + key + ".lora_A.weight"
        let loraBKey = "transformer." + key + ".lora_B.weight"
        let weightKey = key + ".weight"

        // Best-effort fusion: silently skip layers missing either LoRA
        // matrix or the base weight, matching the original behavior.
        guard let loraA = loraWeight[loraAKey],
            let loraB = loraWeight[loraBKey],
            let baseWeight = fusedWeights[weightKey]
        else { continue }

        // W' = W + scale * (B · A)
        let loraDelta = MLX.matmul(loraB, loraA)
        fusedWeights[weightKey] = baseWeight + loraScale * loraDelta
    }

    return fusedWeights
}
113+
80114public struct FluxConfiguration : Sendable {
81- public let id : String
115+ public var id : String
82116 let files : [ FileKey : String ]
83117 public let defaultParameters : @Sendable ( ) -> EvaluateParameters
84118 let factory :
85- @Sendable ( HubApi, FluxConfiguration, LoadConfiguration) throws ->
119+ @Sendable ( HubApi, FluxConfiguration, LoadConfiguration) async throws ->
86120 FLUX
87121
88122 public func download(
@@ -94,9 +128,9 @@ public struct FluxConfiguration: Sendable {
94128 }
95129
96130 public func textToImageGenerator( hub: HubApi = HubApi ( ) , configuration: LoadConfiguration )
97- throws -> TextToImageGenerator ?
131+ async throws -> TextToImageGenerator ?
98132 {
99- try factory ( hub, self , configuration) as? TextToImageGenerator
133+ try await factory ( hub, self , configuration) as? TextToImageGenerator
100134 }
101135
102136 public static let flux1Schnell = FluxConfiguration (
@@ -113,6 +147,20 @@ public struct FluxConfiguration: Sendable {
113147 factory: { hub, fluxConfiguration, loadConfiguration in
114148 let flux = try Flux1Schnell (
115149 hub: hub, configuration: fluxConfiguration, dType: loadConfiguration. dType)
150+
151+ if let loraPath = loadConfiguration. loraPath {
152+ let loraWeight = try await flux. loadLoraWeights (
153+ hub: hub, loraPath: loraPath, dType: loadConfiguration. dType)
154+
155+ let weights = fuseLoraWeights (
156+ transform: flux. transformer,
157+ transformerWeight: Dictionary (
158+ uniqueKeysWithValues: flux. transformer. parameters ( ) . flattened ( ) ) , loraWeight: loraWeight
159+ )
160+
161+ flux. transformer. update ( parameters: ModuleParameters . unflattened ( weights) )
162+ }
163+
116164 if loadConfiguration. quantize {
117165 quantize ( model: flux. clipEncoder, filter: { k, m in m is Linear } )
118166 quantize ( model: flux. t5Encoder, filter: { k, m in m is Linear } )
@@ -141,6 +189,20 @@ public struct FluxConfiguration: Sendable {
141189 factory: { hub, fluxConfiguration, loadConfiguration in
142190 let flux = try Flux1Dev (
143191 hub: hub, configuration: fluxConfiguration, dType: loadConfiguration. dType)
192+
193+ if let loraPath = loadConfiguration. loraPath {
194+ let loraWeight = try await flux. loadLoraWeights (
195+ hub: hub, loraPath: loraPath, dType: loadConfiguration. dType)
196+
197+ let weights = fuseLoraWeights (
198+ transform: flux. transformer,
199+ transformerWeight: Dictionary (
200+ uniqueKeysWithValues: flux. transformer. parameters ( ) . flattened ( ) ) , loraWeight: loraWeight
201+ )
202+
203+ flux. transformer. update ( parameters: ModuleParameters . unflattened ( weights) )
204+ }
205+
144206 if loadConfiguration. quantize {
145207 quantize ( model: flux. clipEncoder, filter: { k, m in m is Linear } )
146208 quantize ( model: flux. t5Encoder, filter: { k, m in m is Linear } )
0 commit comments