@@ -612,7 +612,7 @@ def prepare_latents(self,

         if image.shape[1] != num_channels_latents:
             image = self.vae.encode(image).latent
-            image_latents = image * self.vae.config.scaling_factor
+            image_latents = image * self.vae.config.scaling_factor * self.scheduler.config.sigma_data
         else:
             image_latents = image
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
@@ -632,8 +632,10 @@ def prepare_latents(self,
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self.scheduler.add_noise(image_latents, timestep, noise)
+        # adapted from https://github.com/huggingface/diffusers/blob/c36f8487df35895421c15f351c7d360bd680[…]/examples/research_projects/sana/train_sana_sprint_diffusers.py
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.scheduler.config.sigma_data
+        # latents = self.scheduler.add_noise(image_latents, timestep, noise)
+        latents = torch.cos(timestep) * image_latents + torch.sin(timestep) * noise
         return latents

     @property
@@ -871,7 +873,8 @@ def __call__(
             latents,
         )

-        latents = latents * self.scheduler.config.sigma_data
+        # Redundant: prepare_latents already applies the sigma_data scaling.
+        # latents = latents * self.scheduler.config.sigma_data

         guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
         guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
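
Taken together, the hunks swap the scheduler's `add_noise` for TrigFlow-style noising: both the encoded image latents and the Gaussian noise are scaled to `sigma_data`, then mixed with `cos`/`sin` weights so the result keeps a roughly constant scale at every timestep. Below is a minimal, self-contained sketch of that step, assuming `timestep` is an angle in radians in [0, pi/2]; the function name, the `sigma_data` value, and the test shapes are placeholders, not taken from the diff.

```python
import torch


def trigflow_add_noise(image_latents: torch.Tensor,
                       noise: torch.Tensor,
                       timestep: torch.Tensor,
                       sigma_data: float) -> torch.Tensor:
    """TrigFlow-style noising: interpolate clean latents and Gaussian noise
    on a circle so the result keeps a roughly constant scale (~sigma_data).

    Assumes `image_latents` is already scaled to sigma_data (as done in
    prepare_latents above) and `timestep` is an angle in [0, pi/2].
    """
    # Mirror the `randn_tensor(...) * self.scheduler.config.sigma_data` line.
    noise = noise * sigma_data
    return torch.cos(timestep) * image_latents + torch.sin(timestep) * noise


if __name__ == "__main__":
    sigma_data = 0.5  # placeholder value, not taken from the diff
    x = torch.randn(2, 4, 8, 8) * sigma_data  # stand-in for scaled VAE latents
    eps = torch.randn_like(x)
    for t in (0.0, 0.7, torch.pi / 2):
        z = trigflow_add_noise(x, eps, torch.tensor(t), sigma_data)
        print(f"t={t:.2f}  std={z.std().item():.3f}")  # stays near sigma_data
```

Since cos(t)^2 + sin(t)^2 = 1 and both operands have standard deviation near `sigma_data`, the mixed latents keep that scale for any `t`, which is why the extra `latents * sigma_data` in `__call__` would double-apply the scaling.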