Commit 6ac5cbb

Use scale_noise directly and fix VAE decoding
1 parent 6ff1af8 commit 6ac5cbb

File tree

1 file changed: +36 −32 lines changed

src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py

Lines changed: 36 additions & 32 deletions
@@ -52,7 +52,7 @@
     >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
     >>> init_image = init_image.resize((768, 512))

-    >>> pipe = AuraFlowImg2ImgPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
+    >>> pipe = AuraFlowImg2ImgPipeline.from_pretrained("fal/AuraFlow-v0.3", torch_dtype=torch.float16)
     >>> pipe = pipe.to("cuda")
     >>> prompt = "A fantasy landscape, trending on artstation"
     >>> image = pipe(prompt=prompt, image=init_image, strength=0.75, num_inference_steps=50).images[0]

@@ -338,19 +338,20 @@ def prepare_latents(
         return latents

     def get_timesteps(self, num_inference_steps, strength, device):
-        # 1. Call set_timesteps with num_inference_steps
+        # Set timesteps using the full range initially
         self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps.to(device=device)

-        # 2. Calculate strength-based number of steps and offset
-        init_timestep_count = min(int(num_inference_steps * strength), num_inference_steps)
-        t_start = max(num_inference_steps - init_timestep_count, 0)
+        if len(timesteps) != num_inference_steps:
+            num_inference_steps = len(timesteps)  # Adjust if scheduler changed num_steps

-        # 3. Get the timesteps *after* set_timesteps has been called (now has length num_inference_steps)
+        # Get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
         timesteps = self.scheduler.timesteps[t_start:]

-        # 4. Return the correct slice and the number of actual steps
-        num_actual_inference_steps = len(timesteps)
-        return timesteps, num_actual_inference_steps
+        return timesteps, num_inference_steps - t_start

     def prepare_img2img_latents(
         self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None

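As a quick sanity check on the rewritten get_timesteps, here is a minimal sketch of the strength-to-timestep arithmetic it now performs, assuming the scheduler hands back exactly num_inference_steps timesteps; the helper name and the example values are illustrative only, not part of this commit:

    def sliced_schedule(num_inference_steps: int, strength: float) -> tuple[int, int]:
        # Mirrors the arithmetic above: how many leading timesteps are skipped,
        # and how many denoising steps actually run.
        init_timestep = min(num_inference_steps * strength, num_inference_steps)
        t_start = int(max(num_inference_steps - init_timestep, 0))
        return t_start, num_inference_steps - t_start

    print(sliced_schedule(50, 0.75))  # (12, 38): skip the first 12 timesteps, run 38 steps
    print(sliced_schedule(50, 1.0))   # (0, 50): full denoising from pure noise
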
@@ -385,11 +386,20 @@ def prepare_img2img_latents(
                 f" Batch size must be divisible by the image batch size."
             )

+        # Temporarily move VAE to float32 for encoding
+        vae_dtype = self.vae.dtype
+        if vae_dtype != torch.float32:
+            self.vae.to(dtype=torch.float32)
+
         # encode the init image into latents and scale the latents
         # 1. Get VAE distribution parameters (on device)
-        latent_dist = self.vae.encode(image).latent_dist
+        latent_dist = self.vae.encode(image.to(dtype=torch.float32)).latent_dist
         mean, std = latent_dist.mean, latent_dist.std  # Already on device

+        # Restore VAE dtype
+        if vae_dtype != torch.float32:
+            self.vae.to(dtype=vae_dtype)
+
         # 2. Sample noise for each batch element individually if using multiple generators
         if isinstance(generator, list):
             sample = torch.cat(

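The encode path above and the decode path later in this commit follow the same move-to-float32-then-restore pattern. If that pattern were factored out, it could look like the sketch below; the vae_in_fp32 helper is hypothetical and not something this commit (or diffusers) defines:

    import contextlib

    import torch


    @contextlib.contextmanager
    def vae_in_fp32(vae):
        # Hypothetical helper: run the VAE in float32 for the duration of the block,
        # then restore whatever dtype it had before.
        original_dtype = vae.dtype
        if original_dtype != torch.float32:
            vae.to(dtype=torch.float32)
        try:
            yield vae
        finally:
            if original_dtype != torch.float32:
                vae.to(dtype=original_dtype)


    # Usage sketch (pipe and image are placeholders):
    # with vae_in_fp32(pipe.vae):
    #     latent_dist = pipe.vae.encode(image.to(dtype=torch.float32)).latent_dist
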
@@ -416,7 +426,7 @@ def prepare_img2img_latents(
         latents = latents * self.vae.config.scaling_factor

         # get the original timestep using init_timestep
-        init_timestep = timestep
+        init_timestep = timestep  # Use the passed timestep directly

         # add noise to latents using the timesteps
         # Handle noise generation with multiple generators if provided

@@ -439,20 +449,7 @@ def prepare_img2img_latents(
                 latents.shape, generator=generator, device=generator_device, dtype=latents.dtype
             ).to(latents.device)

-        # Ensure timestep tensor is on the same device
-        t = init_timestep.to(latents.device)
-
-        # Normalize timestep to [0, 1] range (using scheduler's config)
-        t = t / self.scheduler.config.num_train_timesteps
-
-        # Reshape t to match the dimensions needed for broadcasting
-        required_dims = len(latents.shape)
-        current_dims = len(t.shape)
-        for _ in range(required_dims - current_dims):
-            t = t.unsqueeze(-1)
-
-        # Interpolation: x_t = t * x_1 + (1 - t) * x_0
-        latents = t * noise + (1 - t) * latents
+        latents = self.scheduler.scale_noise(latents, init_timestep, noise)

         return latents

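The removed block reimplemented the flow-matching forward process by hand, normalizing the timestep by num_train_timesteps. scheduler.scale_noise applies the same x_t = sigma * noise + (1 - sigma) * x_0 interpolation but uses the scheduler's own sigma for the given timestep, so any timestep shift in the schedule is respected. A minimal sketch, assuming the pipeline's scheduler is diffusers' FlowMatchEulerDiscreteScheduler; the shapes and step count are arbitrary:

    import torch

    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler.set_timesteps(num_inference_steps=50)

    latents = torch.randn(1, 4, 64, 64)  # "clean" image latents (illustrative shape)
    noise = torch.randn_like(latents)
    t = scheduler.timesteps[:1]          # first timestep of the (strength-sliced) schedule

    # x_t = sigma * noise + (1 - sigma) * x_0, with sigma looked up by the scheduler
    noisy_latents = scheduler.scale_noise(latents, t, noise)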

@@ -657,8 +654,10 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

         # 5. Prepare timesteps
-        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
-        latent_timestep = timesteps[:1]
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps, strength, device
+        )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)  # Get the first timestep(s) for initial noise

         # 6. Prepare latent variables
         latents = self.prepare_img2img_latents(

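The repeat pairs every latent in the batch with its own copy of the starting timestep, matching the convention used by diffusers' img2img pipelines when they pass timesteps into the noising step. A small illustration with made-up numbers:

    import torch

    batch_size, num_images_per_prompt = 2, 2
    timesteps = torch.tensor([1000.0, 980.0, 960.0])  # illustrative schedule

    # One copy of the first (noisiest) timestep for every image being generated.
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
    print(latent_timestep)  # tensor([1000., 1000., 1000., 1000.])
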
@@ -727,11 +726,11 @@ def __call__(
         if output_type == "latent":
             image = latents
         else:
-            # make sure the VAE is in float32 mode, as it overflows in float16
-            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
-            if needs_upcasting:
-                self.upcast_vae()
-                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+            # Always upcast VAE to float32 for decoding
+            vae_dtype = self.vae.dtype
+            if vae_dtype != torch.float32:
+                self.vae.to(dtype=torch.float32)
+            latents = latents.to(dtype=torch.float32)

             # Apply proper scaling factor and shift factor if available
             if (

@@ -746,6 +745,11 @@ def __call__(
                 latents = latents / self.vae.config.scaling_factor

             image = self.vae.decode(latents, return_dict=False)[0]
+
+            # Restore VAE dtype
+            if vae_dtype != torch.float32:
+                self.vae.to(dtype=vae_dtype)
+
             image = self.image_processor.postprocess(image, output_type=output_type)

         # Offload all models
