Skip to content

Commit bdbe459

Browse files
committed
In Swift, allow serial unet predictions
1 parent bef26ae commit bdbe459

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,38 @@ public struct StableDiffusionPipeline: StableDiffusionPipelineProtocol {
283283

284284
// Predict noise residuals from latent samples
285285
// and current time step conditioned on hidden states
286-
var noise = try unet.predictNoise(
287-
latents: latentUnetInput,
288-
timeStep: t,
289-
hiddenStates: hiddenStates,
290-
additionalResiduals: additionalResiduals
291-
)
286+
var noise : [MLShapedArray<Float32>]
287+
if unet.latentSampleShape[0] >= 2 || config.guidanceScale < 1.0 {
288+
// One predict call from the uNet, using batching if needed
289+
noise = try unet.predictNoise(
290+
latents: latentUnetInput,
291+
timeStep: t,
292+
hiddenStates: hiddenStates,
293+
additionalResiduals: additionalResiduals
294+
)
295+
} else {
296+
// Serial predictions from uNet
297+
var hidden0 = MLShapedArray<Float32>(converting: hiddenStates[0])
298+
hidden0 = MLShapedArray(scalars: hidden0.scalars, shape: [1]+hidden0.shape)
299+
let noise_pred_uncond = try unet.predictNoise(
300+
latents: latents,
301+
timeStep: t,
302+
hiddenStates: hidden0,
303+
additionalResiduals: additionalResiduals
304+
)
305+
306+
var hidden1 = MLShapedArray<Float32>(converting: hiddenStates[1])
307+
hidden1 = MLShapedArray(scalars: hidden1.scalars, shape: [1]+hidden1.shape)
308+
let noise_pred_text = try unet.predictNoise(
309+
latents: latents,
310+
timeStep: t,
311+
hiddenStates: hidden1,
312+
additionalResiduals: additionalResiduals
313+
)
314+
315+
noise = [MLShapedArray<Float32>(concatenating: [noise_pred_uncond[0], noise_pred_text[0]],
316+
alongAxis: 0)]
317+
}
292318

293319
if config.guidanceScale >= 1.0 {
294320
noise = performGuidance(noise, config.guidanceScale)

0 commit comments

Comments
 (0)