Skip to content

Commit 6775c2c

Browse files
Allow disabling classifier-free guidance for inference in Swift code
1 parent 1c194da commit 6775c2c

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -210,22 +210,24 @@ public struct StableDiffusionPipeline: StableDiffusionPipelineProtocol {
210210
progressHandler: (Progress) -> Bool = { _ in true }
211211
) throws -> [CGImage?] {
212212

213-
// Encode the input prompt and negative prompt
214-
let promptEmbedding = try textEncoder.encode(config.prompt)
215-
let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt)
213+
// Encode the input prompt
214+
var promptEmbedding = try textEncoder.encode(config.prompt)
215+
216+
if config.guidanceScale >= 1.0 {
217+
// Encode the negative prompt for classifier-free guidance (the Unet hidden state conversion happens after this block)
218+
// Concatenate the prompt and negative prompt embeddings
219+
let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt)
220+
promptEmbedding = MLShapedArray<Float32>(
221+
concatenating: [negativePromptEmbedding, promptEmbedding],
222+
alongAxis: 0
223+
)
224+
}
216225

217226
if reduceMemory {
218227
textEncoder.unloadResources()
219228
}
220229

221-
// Convert to Unet hidden state representation
222-
// Concatenate the prompt and negative prompt embeddings
223-
let concatEmbedding = MLShapedArray<Float32>(
224-
concatenating: [negativePromptEmbedding, promptEmbedding],
225-
alongAxis: 0
226-
)
227-
228-
let hiddenStates = useMultilingualTextEncoder ? concatEmbedding : toHiddenStates(concatEmbedding)
230+
let hiddenStates = useMultilingualTextEncoder ? promptEmbedding : toHiddenStates(promptEmbedding)
229231

230232
/// Setup schedulers
231233
let scheduler: [Scheduler] = (0..<config.imageCount).map { _ in
@@ -262,8 +264,13 @@ public struct StableDiffusionPipeline: StableDiffusionPipelineProtocol {
262264

263265
// Expand the latents for classifier-free guidance (only when guidanceScale >= 1.0)
264266
// and input to the Unet noise prediction model
265-
let latentUnetInput = latents.map {
266-
MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
267+
let latentUnetInput: [MLShapedArray<Float32>]
268+
if config.guidanceScale >= 1.0 {
269+
latentUnetInput = latents.map {
270+
MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
271+
}
272+
} else {
273+
latentUnetInput = latents
267274
}
268275

269276
// Before Unet, execute controlNet and add the output into Unet inputs
@@ -283,7 +290,9 @@ public struct StableDiffusionPipeline: StableDiffusionPipelineProtocol {
283290
additionalResiduals: additionalResiduals
284291
)
285292

286-
noise = performGuidance(noise, config.guidanceScale)
293+
if config.guidanceScale >= 1.0 {
294+
noise = performGuidance(noise, config.guidanceScale)
295+
}
287296

288297
// Have the scheduler compute the previous (t-1) latent
289298
// sample given the predicted noise and current sample

swift/StableDiffusion/pipeline/Unet.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,12 @@ public struct Unet: ResourceManaging {
9595
) throws -> [MLShapedArray<Float32>] {
9696

9797
// Match time step batch dimension to the model / latent samples
98-
let t = MLShapedArray<Float32>(scalars:[Float(timeStep), Float(timeStep)],shape:[2])
98+
let t: MLShapedArray<Float32>
99+
if hiddenStates.shape[0] == 2 {
100+
t = MLShapedArray(scalars: [Float(timeStep), Float(timeStep)], shape: [2])
101+
} else {
102+
t = MLShapedArray(scalars: [Float(timeStep)], shape: [1])
103+
}
99104

100105
// Form batch input to model
101106
let inputs = try latents.enumerated().map {

0 commit comments

Comments
 (0)