Skip to content

Commit 2a52cd5

Browse files
committed
try scaling differently
1 parent caa0110 commit 2a52cd5

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ def prepare_latents(self,
612612

613613
if image.shape[1] != num_channels_latents:
614614
image = self.vae.encode(image).latent
615-
image_latents = image * self.vae.config.scaling_factor
615+
image_latents = image * self.vae.config.scaling_factor * self.scheduler.config.sigma_data
616616
else:
617617
image_latents = image
618618
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
@@ -632,8 +632,10 @@ def prepare_latents(self,
632632
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
633633
)
634634

635-
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
636-
latents = self.scheduler.add_noise(image_latents, timestep, noise)
635+
# adapt from https://github.com/huggingface/diffusers/blob/c36f8487df35895421c15f351c7d360bd680[…]/examples/research_projects/sana/train_sana_sprint_diffusers.py
636+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.scheduler.config.sigma_data
637+
# latents = self.scheduler.add_noise(image_latents, timestep, noise)
638+
latents = torch.cos(timestep) * image_latents + torch.sin(timestep) * noise
637639
return latents
638640

639641
@property
@@ -871,7 +873,8 @@ def __call__(
871873
latents,
872874
)
873875

874-
latents = latents * self.scheduler.config.sigma_data
876+
# I think this is redundant given the scaling in prepare_latents
877+
#latents = latents * self.scheduler.config.sigma_data
875878

876879
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
877880
guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)

0 commit comments

Comments
 (0)