@@ -799,7 +799,7 @@ def __call__(
799799 )
800800
801801 # 5. Prepare timesteps
802- sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps )
802+ sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps ) if sigmas is None else sigmas
803803 image_seq_len = (int (height ) // self .vae_scale_factor // 2 ) * (int (width ) // self .vae_scale_factor // 2 )
804804 mu = calculate_shift (
805805 image_seq_len ,
@@ -816,9 +816,10 @@ def __call__(
816816 sigmas ,
817817 mu = mu ,
818818 )
819- start_timestep = int (start_timestep * num_inference_steps )
820- stop_timestep = min (int (stop_timestep * num_inference_steps ), num_inference_steps )
821- timesteps , sigmas , num_inference_steps = self .get_timesteps (num_inference_steps , strength )
819+ if do_rf_inversion :
820+ start_timestep = int (start_timestep * num_inference_steps )
821+ stop_timestep = min (int (stop_timestep * num_inference_steps ), num_inference_steps )
822+ timesteps , sigmas , num_inference_steps = self .get_timesteps (num_inference_steps , strength )
822823 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
823824 self ._num_timesteps = len (timesteps )
824825
@@ -833,10 +834,10 @@ def __call__(
833834 y_0 = image_latents .clone ()
834835 # 6. Denoising loop
835836 with self .progress_bar (total = num_inference_steps ) as progress_bar :
836-
837837 for i , t in enumerate (timesteps ):
838- t_i = 1 - t / 1000
839- dt = torch .tensor (1 / (len (timesteps ) - 1 ), device = device )
838+ if do_rf_inversion :
839+ t_i = 1 - t / 1000
840+ dt = torch .tensor (1 / (len (timesteps ) - 1 ), device = device )
840841
841842 if self .interrupt :
842843 continue
0 commit comments