
Commit 9375020

Update src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py
Co-authored-by: Bagheera <[email protected]>
1 parent 1b7fb36 commit 9375020

File tree

1 file changed: +39, -88 lines changed

src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py

Lines changed: 39 additions & 88 deletions
@@ -350,115 +350,66 @@ def get_timesteps(self, num_inference_steps, strength, device):
         return timesteps, num_inference_steps - t_start
 
     def prepare_latents(
-        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
+        self,
+        image,
+        timestep,
+        batch_size,
+        num_images_per_prompt,
+        dtype,
+        device,
+        generator=None,
     ):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
-                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+                f"`image` must be `torch.Tensor`, `PIL.Image.Image` or list, got {type(image)}"
             )
 
-        # Check for latents_mean and latents_std in the VAE config
-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
         image = image.to(device=device, dtype=dtype)
-
         batch_size = batch_size * num_images_per_prompt
 
         if image.shape[1] == 4:
-            latents = image
+            latents_0 = image
         else:
-            if isinstance(generator, list) and len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
-
-            # Handle different batch size scenarios
-            if image.shape[0] < batch_size:
-                if batch_size % image.shape[0] == 0:
-                    # Duplicate the image to match the batch size
-                    additional_image_per_prompt = batch_size // image.shape[0]
-                    image = torch.cat([image] * additional_image_per_prompt, dim=0)
-                else:
-                    raise ValueError(
-                        f"Cannot duplicate `image` of batch size {image.shape[0]} to {batch_size} text prompts."
-                        f" Batch size must be divisible by the image batch size."
-                    )
-
-            # Temporarily move VAE to float32 for encoding
-            vae_dtype = self.vae.dtype
-            if vae_dtype != torch.float32:
+            # Encode with the VAE in float32 for numerical stability
+            orig_dtype = self.vae.dtype
+            if orig_dtype != torch.float32:
                 self.vae.to(dtype=torch.float32)
 
-            # encode the init image into latents and scale the latents
-            # 1. Get VAE distribution parameters (on device)
             latent_dist = self.vae.encode(image.to(dtype=torch.float32)).latent_dist
-            mean, std = latent_dist.mean, latent_dist.std  # Already on device
+            latents_0 = latent_dist.mean  # deterministic: use the posterior mean, not a sample
 
-            # Restore VAE dtype
-            if vae_dtype != torch.float32:
-                self.vae.to(dtype=vae_dtype)
+            if orig_dtype != torch.float32:
+                self.vae.to(dtype=orig_dtype)
 
-            # 2. Sample noise for each batch element individually if using multiple generators
-            if isinstance(generator, list):
-                sample = torch.cat(
-                    [
-                        randn_tensor(
-                            (1, *mean.shape[1:]),
-                            generator=generator[i],
-                            device=mean.device,
-                            dtype=mean.dtype,
-                        )
-                        for i in range(batch_size)
-                    ]
-                )
-            else:
-                # Single generator - use its device if it has one
-                sample = randn_tensor(mean.shape, generator=generator, device=mean.device, dtype=mean.dtype)
-
-            # Compute latents
-            latents = mean + std * sample
-
-            # Apply standardization if VAE has mean and std defined in config
-            if latents_mean is not None and latents_std is not None:
-                latents_mean = latents_mean.to(device=device, dtype=dtype)
-                latents_std = latents_std.to(device=device, dtype=dtype)
-                latents = (latents - latents_mean) * self.vae.config.scaling_factor / latents_std
-            else:
-                # Scale latents
-                latents = latents * self.vae.config.scaling_factor
-
-        # get the original timestep using init_timestep
-        init_timestep = timestep  # Use the passed timestep directly
-
-        # add noise to latents using the timesteps
-        # Handle noise generation with multiple generators if provided
-        if isinstance(generator, list):
-            noise = torch.cat(
-                [
-                    randn_tensor(
-                        (1, *latents.shape[1:]),
-                        generator=generator[i],
-                        device=latents.device,
-                        dtype=latents.dtype,
-                    )
-                    for i in range(batch_size)
-                ]
-            )
-        else:
-            # Single generator - use its device if it has one
-            noise = randn_tensor(
-                latents.shape, generator=generator, device=latents.device, dtype=latents.dtype
-            )
+            # Scale latents
+            latents_0 = latents_0 * self.vae.config.scaling_factor
 
-        latents = self.scheduler.scale_noise(latents, init_timestep, noise)
+        # Replicate to match `batch_size`
+        if latents_0.shape[0] != batch_size:
+            if batch_size % latents_0.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate image batch of size {latents_0.shape[0]} "
+                    f"to effective batch size {batch_size}."
+                )
+            repeats = batch_size // latents_0.shape[0]
+            latents_0 = latents_0.repeat(repeats, 1, 1, 1)
+
+        noise = randn_tensor(
+            latents_0.shape,
+            generator=generator,
+            device=latents_0.device,
+            dtype=latents_0.dtype,
+        )
+
+        # Make sure `timestep` is 1-D and matches the batch
+        if isinstance(timestep, (int, float)):
+            timestep = torch.tensor([timestep], device=latents_0.device, dtype=latents_0.dtype)
+        timestep = timestep.expand(latents_0.shape[0])
+
+        latents = self.scheduler.scale_noise(latents_0, timestep, noise)
         return latents
 
+
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
     def upcast_vae(self):
         dtype = self.vae.dtype
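Some notes on the refactor, with short sketches below. First, switching from `mean + std * sample` to `latent_dist.mean` makes image encoding deterministic: `AutoencoderKL.encode` returns a diagonal Gaussian posterior, and taking its mean removes one source of run-to-run variance, leaving the freshly drawn `noise` as the only randomness. A minimal sketch of the two strategies, assuming `vae` is an `AutoencoderKL` and `image` is a [-1, 1]-normalized float32 tensor of shape (B, 3, H, W); neither name is taken from this file:

    import torch
    from diffusers import AutoencoderKL

    def encode_deterministic(vae: AutoencoderKL, image: torch.Tensor) -> torch.Tensor:
        latent_dist = vae.encode(image).latent_dist
        return latent_dist.mean * vae.config.scaling_factor  # no sampling step

    def encode_stochastic(vae: AutoencoderKL, image: torch.Tensor, generator=None) -> torch.Tensor:
        latent_dist = vae.encode(image).latent_dist
        # .sample() draws mean + std * eps, so repeated calls give different latents
        return latent_dist.sample(generator=generator) * vae.config.scaling_factor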

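Second, the hand-rolled per-generator loops could be deleted because diffusers' `randn_tensor` helper already accepts either a single `torch.Generator` or a list of them, drawing one independently seeded sample per batch element in the list case. A sketch of the behavior the new code relies on (the shape is illustrative):

    import torch
    from diffusers.utils.torch_utils import randn_tensor

    shape = (2, 4, 64, 64)  # (batch, channels, height, width)
    gens = [torch.Generator().manual_seed(s) for s in (0, 1)]

    # List of generators: one seeded draw per batch element (len must equal shape[0]).
    noise_per_sample = randn_tensor(shape, generator=gens, dtype=torch.float32)
    # Single generator: the whole tensor is drawn in one call.
    noise_shared = randn_tensor(shape, generator=gens[0], dtype=torch.float32)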
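Third, `scheduler.scale_noise` is the flow-matching analogue of `add_noise` and expects a 1-D timestep tensor with one entry per latent, which is why the new code coerces scalar timesteps and expands them to the batch. Conceptually, the forward process linearly interpolates between the clean latents and pure noise; a sketch, not the scheduler's exact implementation:

    import torch

    def scale_noise_sketch(latents_0: torch.Tensor, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        # sigma in [0, 1] with shape (B,); the real scheduler derives it from the
        # 1-D timestep tensor. sigma=0 returns the clean latents, sigma=1 pure noise.
        sigma = sigma.view(-1, 1, 1, 1)  # broadcast over (B, C, H, W)
        return sigma * noise + (1.0 - sigma) * latents_0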
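Finally, for context, an end-to-end call of the pipeline this file defines. The `AuraFlowImg2ImgPipeline` class name and the `fal/AuraFlow` checkpoint id are illustrative assumptions, not taken from the diff:

    import torch
    from diffusers import AuraFlowImg2ImgPipeline  # assumed export for this pipeline
    from diffusers.utils import load_image

    pipe = AuraFlowImg2ImgPipeline.from_pretrained(
        "fal/AuraFlow",  # assumed checkpoint id
        torch_dtype=torch.float16,
    ).to("cuda")

    init_image = load_image("input.png")
    result = pipe(
        prompt="a watercolor landscape",
        image=init_image,
        strength=0.6,  # lower values add less noise and stay closer to the input
        num_inference_steps=30,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]
    result.save("output.png")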