@@ -580,9 +580,7 @@ def prepare_latents(
580580 else :
581581 image_latents = torch .cat ([image_latents ], dim = 0 )
582582
583- noise = randn_tensor (shape , generator = generator , device = device , dtype = dtype )
584- latents = self .scheduler .scale_noise (image_latents , timestep , noise )
585- latents = self ._pack_latents (latents , batch_size , num_channels_latents , height , width )
583+ latents = self ._pack_latents (image_latents , batch_size , num_channels_latents , height , width )
586584 return latents , latent_image_ids
587585
588586 @property
@@ -780,9 +778,6 @@ def __call__(
780778 sigmas ,
781779 mu = mu ,
782780 )
783- self .scheduler .sigmas = self .scheduler .sigmas .flip (0 )
784- self .scheduler .timesteps = self .scheduler .timesteps .flip (0 )
785- self .scheduler .sigmas [0 ] += 1e-6
786781 print (f"self.scheduler.sigmas { self .scheduler .sigmas } " )
787782 print (f"self.scheduler.timesteps { self .scheduler .timesteps } " )
788783 timesteps , num_inference_steps = self .get_timesteps (num_inference_steps , strength , device )
@@ -866,7 +861,7 @@ def __call__(
866861 # compute the previous noisy sample x_t -> x_t-1
867862 latents_dtype = latents .dtype
868863 # Next state: $X_{t_{i+1}} = X_{t_i} + \hat{v}_{t_i}(X_{t_i}) \cdot (\sigma(t_{i+1}) - \sigma(t_i))$
869- latents = latents + controlled_vector_field * (self . scheduler . sigmas [i + 1 ] - self . scheduler . sigmas [i ])
864+ latents = latents + controlled_vector_field * (sigmas [i ] - sigmas [i + 1 ])
870865
871866 if latents .dtype != latents_dtype :
872867 if torch .backends .mps .is_available ():
0 commit comments