@@ -143,11 +143,11 @@ open class FLUX {
143143 }
144144}
145145
146- public class Flux1Schnell : FLUX , TextToImageGenerator , @unchecked Sendable {
146+ public class Flux1Schnell : FLUX , TextToImageGenerator , ImageToImageGenerator , @unchecked Sendable {
147147 let clipTokenizer : CLIPTokenizer
148148 let t5Tokenizer : any Tokenizer
149- let transformer : MultiModalDiffusionTransformer
150- let vae : VAE
149+ public let transformer : MultiModalDiffusionTransformer
150+ public let vae : VAE
151151 let t5Encoder : T5Encoder
152152 let clipEncoder : CLIPEncoder
153153
@@ -186,27 +186,7 @@ public class Flux1Schnell: FLUX, TextToImageGenerator, @unchecked Sendable {
186186 return transformer
187187 }
188188
189- public func generateLatents( parameters: EvaluateParameters ) -> DenoiseIterator {
190- let latentsShape = [ 1 , ( parameters. height / 16 ) * ( parameters. width / 16 ) , 64 ]
191- let latents : MLXArray
192- if let seed = parameters. seed {
193- latents = MLXRandom . normal ( latentsShape, key: MLXRandom . key ( seed) )
194- } else {
195- latents = MLXRandom . normal ( latentsShape)
196- }
197- let ( promptEmbeddings, pooledPromptEmbeddings) = conditionText ( prompt: parameters. prompt)
198-
199- return DenoiseIterator (
200- steps: parameters. numInferenceSteps,
201- promptEmbeddings: promptEmbeddings,
202- pooledPromptEmbeddings: pooledPromptEmbeddings,
203- latents: latents,
204- evaluateParameters: parameters,
205- transformer: transformer
206- )
207- }
208-
209- func conditionText( prompt: String ) -> ( MLXArray , MLXArray ) {
189+ public func conditionText( prompt: String ) -> ( MLXArray , MLXArray ) {
210190 let t5Tokens = t5Tokenizer. encode ( text: prompt, addSpecialTokens: true )
211191 let paddedT5Tokens =
212192 Array ( t5Tokens. prefix ( 256 ) )
@@ -231,19 +211,19 @@ public class Flux1Schnell: FLUX, TextToImageGenerator, @unchecked Sendable {
231211 public func detachedDecoder( ) -> ImageDecoder {
232212 let autoencoder = self . vae
233213 func decode( xt: MLXArray ) -> MLXArray {
234- var x = autoencoder. decode ( latents : xt)
214+ var x = autoencoder. decode ( xt)
235215 x = clip ( x / 2 + 0.5 , min: 0 , max: 1 )
236216 return x
237217 }
238218 return decode ( xt: )
239219 }
240220}
241221
242- public class Flux1Dev : FLUX , TextToImageGenerator , @unchecked Sendable {
222+ public class Flux1Dev : FLUX , TextToImageGenerator , ImageToImageGenerator , @unchecked Sendable {
243223 let clipTokenizer : CLIPTokenizer
244224 let t5Tokenizer : any Tokenizer
245- let transformer : MultiModalDiffusionTransformer
246- let vae : VAE
225+ public let transformer : MultiModalDiffusionTransformer
226+ public let vae : VAE
247227 let t5Encoder : T5Encoder
248228 let clipEncoder : CLIPEncoder
249229
@@ -283,27 +263,7 @@ public class Flux1Dev: FLUX, TextToImageGenerator, @unchecked Sendable {
283263 return transformer
284264 }
285265
286- public func generateLatents( parameters: EvaluateParameters ) -> DenoiseIterator {
287- let latentsShape = [ 1 , ( parameters. height / 16 ) * ( parameters. width / 16 ) , 64 ]
288- let latents : MLXArray
289- if let seed = parameters. seed {
290- latents = MLXRandom . normal ( latentsShape, key: MLXRandom . key ( seed) )
291- } else {
292- latents = MLXRandom . normal ( latentsShape)
293- }
294- let ( promptEmbeddings, pooledPromptEmbeddings) = conditionText ( prompt: parameters. prompt)
295-
296- return DenoiseIterator (
297- steps: parameters. numInferenceSteps,
298- promptEmbeddings: promptEmbeddings,
299- pooledPromptEmbeddings: pooledPromptEmbeddings,
300- latents: latents,
301- evaluateParameters: parameters,
302- transformer: transformer
303- )
304- }
305-
306- func conditionText( prompt: String ) -> ( MLXArray , MLXArray ) {
266+ public func conditionText( prompt: String ) -> ( MLXArray , MLXArray ) {
307267 let t5Tokens = t5Tokenizer. encode ( text: prompt, addSpecialTokens: true )
308268 let paddedT5Tokens =
309269 Array ( t5Tokens. prefix ( 512 ) )
@@ -328,7 +288,7 @@ public class Flux1Dev: FLUX, TextToImageGenerator, @unchecked Sendable {
328288 public func detachedDecoder( ) -> ImageDecoder {
329289 let autoencoder = self . vae
330290 func decode( xt: MLXArray ) -> MLXArray {
331- var x = autoencoder. decode ( latents : xt)
291+ var x = autoencoder. decode ( xt)
332292 x = clip ( x / 2 + 0.5 , min: 0 , max: 1 )
333293 return x
334294 }
@@ -349,7 +309,87 @@ public protocol ImageGenerator {
349309}
350310
351311public protocol TextToImageGenerator : ImageGenerator , Sendable {
352- func generateLatents( parameters: EvaluateParameters ) -> DenoiseIterator
312+ var transformer : MultiModalDiffusionTransformer { get }
313+ func conditionText( prompt: String ) -> ( MLXArray , MLXArray )
314+ }
315+
316+ extension TextToImageGenerator {
317+ public func generateLatents( parameters: EvaluateParameters ) -> DenoiseIterator {
318+ let latentsShape = [ 1 , ( parameters. height / 16 ) * ( parameters. width / 16 ) , 64 ]
319+ let latents : MLXArray
320+ if let seed = parameters. seed {
321+ latents = MLXRandom . normal ( latentsShape, key: MLXRandom . key ( seed) )
322+ } else {
323+ latents = MLXRandom . normal ( latentsShape)
324+ }
325+ let ( promptEmbeddings, pooledPromptEmbeddings) = conditionText ( prompt: parameters. prompt)
326+ return DenoiseIterator (
327+ steps: parameters. numInferenceSteps,
328+ promptEmbeddings: promptEmbeddings,
329+ pooledPromptEmbeddings: pooledPromptEmbeddings,
330+ latents: latents,
331+ evaluateParameters: parameters,
332+ transformer: transformer
333+ )
334+ }
335+ }
336+
337+ /// Public interface for transforming a text prompt into an image.
338+ ///
339+ /// Steps:
340+ ///
341+ /// - ``generateLatents(image:parameters:strength:)``
342+ /// - evaluate each of the latents from the iterator
343+ /// - ``ImageGenerator/decode(xt:)`` or ``ImageGenerator/detachedDecoder()`` to convert the final latent into an image
344+ /// - use ``Image`` to save the image
345+ public protocol ImageToImageGenerator : ImageGenerator , Sendable {
346+ var transformer : MultiModalDiffusionTransformer { get }
347+ var vae : VAE { get }
348+ func conditionText( prompt: String ) -> ( MLXArray , MLXArray )
349+ func generateLatents( image: MLXArray , parameters: EvaluateParameters , strength: Float )
350+ -> DenoiseIterator
351+ }
352+
353+ extension ImageToImageGenerator {
354+ internal func packLatents( latents: MLXArray , height: Int , width: Int ) -> MLXArray {
355+ let reshaped = latents. reshaped ( 1 , height / 16 , 2 , width / 16 , 2 , 16 )
356+ let transposed = reshaped. transposed ( 0 , 1 , 3 , 5 , 2 , 4 )
357+ return transposed. reshaped ( 1 , ( height / 16 ) * ( width / 16 ) , 64 )
358+ }
359+
360+ public func generateLatents( image: MLXArray , parameters: EvaluateParameters , strength: Float )
361+ -> DenoiseIterator
362+ {
363+ if let seed = parameters. seed {
364+ MLXRandom . seed ( seed)
365+ }
366+ let noise = MLXRandom . normal ( [ 1 , ( parameters. height / 16 ) * ( parameters. width / 16 ) , 64 ] )
367+
368+ // Calculate the start step and number of steps based on strength
369+ let strength = max ( 0.0 , min ( 1.0 , strength) )
370+
371+ let startStep = max ( 1 , Int ( Float ( parameters. numInferenceSteps) * strength) )
372+
373+ var latents = vae. encode ( image [ . newAxis] )
374+
375+ latents = packLatents ( latents: latents, height: parameters. height, width: parameters. width)
376+
377+ let sigma = parameters. sigmas [ startStep]
378+
379+ latents = ( latents * ( 1 - sigma) + sigma * noise)
380+
381+ let ( promptEmbeddings, pooledPromptEmbeddings) = conditionText ( prompt: parameters. prompt)
382+
383+ return DenoiseIterator (
384+ startStep: startStep,
385+ steps: parameters. numInferenceSteps,
386+ promptEmbeddings: promptEmbeddings,
387+ pooledPromptEmbeddings: pooledPromptEmbeddings,
388+ latents: latents,
389+ evaluateParameters: parameters,
390+ transformer: transformer
391+ )
392+ }
353393}
354394
355395public typealias ImageDecoder = ( MLXArray ) -> MLXArray
@@ -364,14 +404,15 @@ public struct DenoiseIterator: Sequence, IteratorProtocol {
364404 let transformer : MultiModalDiffusionTransformer
365405
366406 init (
367- steps: Int , promptEmbeddings: MLXArray , pooledPromptEmbeddings: MLXArray , latents: MLXArray ,
407+ startStep: Int = 0 , steps: Int , promptEmbeddings: MLXArray , pooledPromptEmbeddings: MLXArray ,
408+ latents: MLXArray ,
368409 evaluateParameters: EvaluateParameters , transformer: MultiModalDiffusionTransformer
369410 ) {
370411 self . steps = steps
371412 self . promptEmbeddings = promptEmbeddings
372413 self . pooledPromptEmbeddings = pooledPromptEmbeddings
373414 self . latents = latents
374- self . i = 0
415+ self . i = startStep
375416 self . evaluateParameters = evaluateParameters
376417 self . transformer = transformer
377418 }
0 commit comments