@@ -283,12 +283,38 @@ public struct StableDiffusionPipeline: StableDiffusionPipelineProtocol {
283
283
284
284
// Predict noise residuals from latent samples
285
285
// and current time step conditioned on hidden states
286
- var noise = try unet. predictNoise (
287
- latents: latentUnetInput,
288
- timeStep: t,
289
- hiddenStates: hiddenStates,
290
- additionalResiduals: additionalResiduals
291
- )
286
+ var noise : [ MLShapedArray < Float32 > ]
287
+ if unet. latentSampleShape [ 0 ] >= 2 || config. guidanceScale < 1.0 {
288
+ // One predict call from the uNet, using batching if needed
289
+ noise = try unet. predictNoise (
290
+ latents: latentUnetInput,
291
+ timeStep: t,
292
+ hiddenStates: hiddenStates,
293
+ additionalResiduals: additionalResiduals
294
+ )
295
+ } else {
296
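+ // Serial predictions from uNet: one call per entry of hiddenStates (index 0, then index 1), results concatenated along axis 0 below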
297
+ var hidden0 = MLShapedArray < Float32 > ( converting: hiddenStates [ 0 ] )
298
+ hidden0 = MLShapedArray ( scalars: hidden0. scalars, shape: [ 1 ] + hidden0. shape)
299
+ let noise_pred_uncond = try unet. predictNoise (
300
+ latents: latents,
301
+ timeStep: t,
302
+ hiddenStates: hidden0,
303
+ additionalResiduals: additionalResiduals
304
+ )
305
+
306
+ var hidden1 = MLShapedArray < Float32 > ( converting: hiddenStates [ 1 ] )
307
+ hidden1 = MLShapedArray ( scalars: hidden1. scalars, shape: [ 1 ] + hidden1. shape)
308
+ let noise_pred_text = try unet. predictNoise (
309
+ latents: latents,
310
+ timeStep: t,
311
+ hiddenStates: hidden1,
312
+ additionalResiduals: additionalResiduals
313
+ )
314
+
315
+ noise = [ MLShapedArray < Float32 > ( concatenating: [ noise_pred_uncond [ 0 ] , noise_pred_text [ 0 ] ] ,
316
+ alongAxis: 0 ) ]
317
+ }
292
318
293
319
if config. guidanceScale >= 1.0 {
294
320
noise = performGuidance ( noise, config. guidanceScale)
0 commit comments