Skip to content

Commit 0ccaa63

Browse files
authored
add image to image generation (#4)
1 parent 36ce31e commit 0ccaa63

File tree

4 files changed

+136
-151
lines changed

4 files changed

+136
-151
lines changed

Sources/FLUX.swift

Lines changed: 94 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,11 @@ open class FLUX {
143143
}
144144
}
145145

146-
public class Flux1Schnell: FLUX, TextToImageGenerator, @unchecked Sendable {
146+
public class Flux1Schnell: FLUX, TextToImageGenerator, ImageToImageGenerator, @unchecked Sendable {
147147
let clipTokenizer: CLIPTokenizer
148148
let t5Tokenizer: any Tokenizer
149-
let transformer: MultiModalDiffusionTransformer
150-
let vae: VAE
149+
public let transformer: MultiModalDiffusionTransformer
150+
public let vae: VAE
151151
let t5Encoder: T5Encoder
152152
let clipEncoder: CLIPEncoder
153153

@@ -186,27 +186,7 @@ public class Flux1Schnell: FLUX, TextToImageGenerator, @unchecked Sendable {
186186
return transformer
187187
}
188188

189-
public func generateLatents(parameters: EvaluateParameters) -> DenoiseIterator {
190-
let latentsShape = [1, (parameters.height / 16) * (parameters.width / 16), 64]
191-
let latents: MLXArray
192-
if let seed = parameters.seed {
193-
latents = MLXRandom.normal(latentsShape, key: MLXRandom.key(seed))
194-
} else {
195-
latents = MLXRandom.normal(latentsShape)
196-
}
197-
let (promptEmbeddings, pooledPromptEmbeddings) = conditionText(prompt: parameters.prompt)
198-
199-
return DenoiseIterator(
200-
steps: parameters.numInferenceSteps,
201-
promptEmbeddings: promptEmbeddings,
202-
pooledPromptEmbeddings: pooledPromptEmbeddings,
203-
latents: latents,
204-
evaluateParameters: parameters,
205-
transformer: transformer
206-
)
207-
}
208-
209-
func conditionText(prompt: String) -> (MLXArray, MLXArray) {
189+
public func conditionText(prompt: String) -> (MLXArray, MLXArray) {
210190
let t5Tokens = t5Tokenizer.encode(text: prompt, addSpecialTokens: true)
211191
let paddedT5Tokens =
212192
Array(t5Tokens.prefix(256))
@@ -231,19 +211,19 @@ public class Flux1Schnell: FLUX, TextToImageGenerator, @unchecked Sendable {
231211
/// Returns a standalone function that decodes latents into image values.
///
/// The returned function captures only the VAE (copied into a local below), not `self`.
///
/// - Returns: A function that runs the VAE decoder on `xt`, then rescales and clamps
///   the result into `[0, 1]`.
public func detachedDecoder() -> ImageDecoder {
232212
let autoencoder = self.vae
233213
func decode(xt: MLXArray) -> MLXArray {
234-
var x = autoencoder.decode(latents: xt)
214+
var x = autoencoder.decode(xt)
235215
// Map x -> x / 2 + 0.5 and clamp to [0, 1]; presumably the VAE output is in [-1, 1] — TODO confirm.
x = clip(x / 2 + 0.5, min: 0, max: 1)
236216
return x
237217
}
238218
return decode(xt:)
239219
}
240220
}
241221

242-
public class Flux1Dev: FLUX, TextToImageGenerator, @unchecked Sendable {
222+
public class Flux1Dev: FLUX, TextToImageGenerator, ImageToImageGenerator, @unchecked Sendable {
243223
let clipTokenizer: CLIPTokenizer
244224
let t5Tokenizer: any Tokenizer
245-
let transformer: MultiModalDiffusionTransformer
246-
let vae: VAE
225+
public let transformer: MultiModalDiffusionTransformer
226+
public let vae: VAE
247227
let t5Encoder: T5Encoder
248228
let clipEncoder: CLIPEncoder
249229

@@ -283,27 +263,7 @@ public class Flux1Dev: FLUX, TextToImageGenerator, @unchecked Sendable {
283263
return transformer
284264
}
285265

286-
public func generateLatents(parameters: EvaluateParameters) -> DenoiseIterator {
287-
let latentsShape = [1, (parameters.height / 16) * (parameters.width / 16), 64]
288-
let latents: MLXArray
289-
if let seed = parameters.seed {
290-
latents = MLXRandom.normal(latentsShape, key: MLXRandom.key(seed))
291-
} else {
292-
latents = MLXRandom.normal(latentsShape)
293-
}
294-
let (promptEmbeddings, pooledPromptEmbeddings) = conditionText(prompt: parameters.prompt)
295-
296-
return DenoiseIterator(
297-
steps: parameters.numInferenceSteps,
298-
promptEmbeddings: promptEmbeddings,
299-
pooledPromptEmbeddings: pooledPromptEmbeddings,
300-
latents: latents,
301-
evaluateParameters: parameters,
302-
transformer: transformer
303-
)
304-
}
305-
306-
func conditionText(prompt: String) -> (MLXArray, MLXArray) {
266+
public func conditionText(prompt: String) -> (MLXArray, MLXArray) {
307267
let t5Tokens = t5Tokenizer.encode(text: prompt, addSpecialTokens: true)
308268
let paddedT5Tokens =
309269
Array(t5Tokens.prefix(512))
@@ -328,7 +288,7 @@ public class Flux1Dev: FLUX, TextToImageGenerator, @unchecked Sendable {
328288
public func detachedDecoder() -> ImageDecoder {
329289
let autoencoder = self.vae
330290
func decode(xt: MLXArray) -> MLXArray {
331-
var x = autoencoder.decode(latents: xt)
291+
var x = autoencoder.decode(xt)
332292
x = clip(x / 2 + 0.5, min: 0, max: 1)
333293
return x
334294
}
@@ -349,7 +309,87 @@ public protocol ImageGenerator {
349309
}
350310

351311
public protocol TextToImageGenerator: ImageGenerator, Sendable {
352-
func generateLatents(parameters: EvaluateParameters) -> DenoiseIterator
312+
var transformer: MultiModalDiffusionTransformer { get }
313+
func conditionText(prompt: String) -> (MLXArray, MLXArray)
314+
}
315+
316+
/// Default text-to-image latent generation, shared by every conforming type.
extension TextToImageGenerator {
317+
/// Creates the initial random latents and the iterator that denoises them.
///
/// - Parameter parameters: Supplies `height`, `width`, the optional `seed`, the `prompt`
///   and `numInferenceSteps`.
/// - Returns: A `DenoiseIterator` built without an explicit `startStep`.
public func generateLatents(parameters: EvaluateParameters) -> DenoiseIterator {
318+
// One row per 16x16 pixel patch, 64 values per row.
// assumes height and width are integers divisible by 16 — TODO confirm
let latentsShape = [1, (parameters.height / 16) * (parameters.width / 16), 64]
319+
let latents: MLXArray
320+
if let seed = parameters.seed {
321+
// An explicit key derived from the seed is passed to `normal`.
latents = MLXRandom.normal(latentsShape, key: MLXRandom.key(seed))
322+
} else {
323+
// No seed: no key is passed, so `normal` uses its default random state.
latents = MLXRandom.normal(latentsShape)
324+
}
325+
let (promptEmbeddings, pooledPromptEmbeddings) = conditionText(prompt: parameters.prompt)
326+
return DenoiseIterator(
327+
steps: parameters.numInferenceSteps,
328+
promptEmbeddings: promptEmbeddings,
329+
pooledPromptEmbeddings: pooledPromptEmbeddings,
330+
latents: latents,
331+
evaluateParameters: parameters,
332+
transformer: transformer
333+
)
334+
}
335+
}
336+
337+
/// Public interface for transforming an existing image into a new image, guided by a text prompt.
338+
///
339+
/// Steps:
340+
///
341+
/// - ``generateLatents(image:parameters:strength:)``
342+
/// - evaluate each of the latents from the iterator
343+
/// - ``ImageGenerator/decode(xt:)`` or ``ImageGenerator/detachedDecoder()`` to convert the final latent into an image
344+
/// - use ``Image`` to save the image
345+
public protocol ImageToImageGenerator: ImageGenerator, Sendable {
346+
/// Transformer handed to the `DenoiseIterator` by the default `generateLatents` implementation.
var transformer: MultiModalDiffusionTransformer { get }
347+
/// Autoencoder whose `encode` is used by the default `generateLatents` implementation
/// to turn the input image into latents.
var vae: VAE { get }
348+
/// Encodes `prompt` into a pair of arrays; the default `generateLatents` implementation
/// passes them on as `promptEmbeddings` and `pooledPromptEmbeddings`.
func conditionText(prompt: String) -> (MLXArray, MLXArray)
349+
/// Builds the denoising iterator for `image`.
///
/// A default implementation is provided in an extension of this protocol.
func generateLatents(image: MLXArray, parameters: EvaluateParameters, strength: Float)
350+
-> DenoiseIterator
351+
}
352+
353+
/// Default image-to-image latent generation, shared by every conforming type.
extension ImageToImageGenerator {
354+
/// Rearranges encoded latents into the `[1, (height / 16) * (width / 16), 64]` layout.
///
/// The input is viewed as `(1, height / 16, 2, width / 16, 2, 16)`, the axes are
/// permuted to `(0, 1, 3, 5, 2, 4)`, and the result is flattened.
/// assumes `latents` holds exactly that many elements (presumably a channels-last
/// `[1, height / 8, width / 8, 16]` VAE output) — TODO confirm
internal func packLatents(latents: MLXArray, height: Int, width: Int) -> MLXArray {
355+
let reshaped = latents.reshaped(1, height / 16, 2, width / 16, 2, 16)
356+
let transposed = reshaped.transposed(0, 1, 3, 5, 2, 4)
357+
return transposed.reshaped(1, (height / 16) * (width / 16), 64)
358+
}
359+
360+
/// Encodes `image`, mixes it with random noise, and returns an iterator that starts
/// denoising part-way through the schedule.
///
/// - Parameters:
///   - image: Source image; a leading axis is added before it is passed to `vae.encode`.
///   - parameters: Supplies `height`, `width`, the optional `seed`, `numInferenceSteps`,
///     `sigmas` and the `prompt`.
///   - strength: Clamped to `[0, 1]`; the start step is
///     `max(1, Int(numInferenceSteps * strength))`, so larger values start later.
/// - Returns: A `DenoiseIterator` whose step counter begins at the computed start step.
public func generateLatents(image: MLXArray, parameters: EvaluateParameters, strength: Float)
361+
-> DenoiseIterator
362+
{
363+
// NOTE(review): the text-to-image path passes an explicit key to `normal`, whereas this
// calls `MLXRandom.seed` — presumably reseeding shared random state; confirm that is intended.
if let seed = parameters.seed {
364+
MLXRandom.seed(seed)
365+
}
366+
// Same shape that `packLatents` produces for the encoded image.
let noise = MLXRandom.normal([1, (parameters.height / 16) * (parameters.width / 16), 64])
367+
368+
// Clamp strength to [0, 1], then derive the step at which denoising starts
let strength = max(0.0, min(1.0, strength))
369+
370+
// Never lower than 1; a strength of 1.0 yields startStep == numInferenceSteps.
let startStep = max(1, Int(Float(parameters.numInferenceSteps) * strength))
371+
372+
var latents = vae.encode(image[.newAxis])
373+
374+
latents = packLatents(latents: latents, height: parameters.height, width: parameters.width)
375+
376+
// assumes `sigmas` has more than `startStep` entries — TODO confirm against EvaluateParameters
let sigma = parameters.sigmas[startStep]
377+
378+
// Blend the encoded image with noise, weighted by sigma.
latents = (latents * (1 - sigma) + sigma * noise)
379+
380+
let (promptEmbeddings, pooledPromptEmbeddings) = conditionText(prompt: parameters.prompt)
381+
382+
return DenoiseIterator(
383+
startStep: startStep,
384+
steps: parameters.numInferenceSteps,
385+
promptEmbeddings: promptEmbeddings,
386+
pooledPromptEmbeddings: pooledPromptEmbeddings,
387+
latents: latents,
388+
evaluateParameters: parameters,
389+
transformer: transformer
390+
)
391+
}
392+
}
353393
}
354394

355395
public typealias ImageDecoder = (MLXArray) -> MLXArray
@@ -364,14 +404,15 @@ public struct DenoiseIterator: Sequence, IteratorProtocol {
364404
let transformer: MultiModalDiffusionTransformer
365405

366406
/// Stores the inputs for a denoising run.
///
/// - Parameters:
///   - startStep: Initial value of the step counter `i`; defaults to 0.
///   - steps: Stored as `steps`.
///   - promptEmbeddings: Stored as `promptEmbeddings`.
///   - pooledPromptEmbeddings: Stored as `pooledPromptEmbeddings`.
///   - latents: Stored as the initial `latents`.
///   - evaluateParameters: Stored as `evaluateParameters`.
///   - transformer: Stored as `transformer`.
init(
367-
steps: Int, promptEmbeddings: MLXArray, pooledPromptEmbeddings: MLXArray, latents: MLXArray,
407+
startStep: Int = 0, steps: Int, promptEmbeddings: MLXArray, pooledPromptEmbeddings: MLXArray,
408+
latents: MLXArray,
368409
evaluateParameters: EvaluateParameters, transformer: MultiModalDiffusionTransformer
369410
) {
370411
self.steps = steps
371412
self.promptEmbeddings = promptEmbeddings
372413
self.pooledPromptEmbeddings = pooledPromptEmbeddings
373414
self.latents = latents
374-
self.i = 0
415+
self.i = startStep
375416
self.evaluateParameters = evaluateParameters
376417
self.transformer = transformer
377418
}

Sources/FluxConfiguration.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ public struct FluxConfiguration: Sendable {
146146
try factory(hub, self, configuration) as? TextToImageGenerator
147147
}
148148

149+
/// Creates a model through `factory` and returns it as an `ImageToImageGenerator`.
///
/// - Parameters:
///   - hub: Hub client passed to `factory`; defaults to a new `HubApi()`.
///   - configuration: Load configuration passed to `factory`.
/// - Returns: The created object, or `nil` when it does not conform to `ImageToImageGenerator`.
/// - Throws: Rethrows any error thrown by `factory`.
///
/// NOTE(review): the method name is UpperCamelCase and identical to the protocol name;
/// Swift naming conventions use lowerCamelCase for methods — consider adding a
/// lowerCamelCase spelling and deprecating this one.
public func ImageToImageGenerator(hub: HubApi = HubApi(), configuration: LoadConfiguration)
150+
throws -> ImageToImageGenerator?
151+
{
152+
try factory(hub, self, configuration) as? ImageToImageGenerator
153+
}
154+
149155
public static let flux1Schnell = FluxConfiguration(
150156
id: "black-forest-labs/FLUX.1-schnell",
151157
files: [

0 commit comments

Comments
 (0)