@@ -626,10 +626,6 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        sigmas=None,
-        flip_schedule=False,
-        even_timesteps=None,
-        divide_timestep=True
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -767,7 +763,7 @@ def __call__(
         )
 
         # 4.Prepare timesteps
-        # Flux noise scheduler $\sigma : [0, 1] \to \mathbb{R}$
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
         mu = calculate_shift(
             image_seq_len,
@@ -784,9 +780,9 @@ def __call__(
             sigmas,
             mu=mu,
         )
-        if flip_schedule:
-            self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
-            self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
+        self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
+        self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
+        self.scheduler.sigmas[0] += 1e-6
         print(f"self.scheduler.sigmas {self.scheduler.sigmas}")
         print(f"self.scheduler.timesteps {self.scheduler.timesteps}")
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
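
Flipping `self.scheduler.sigmas` and `self.scheduler.timesteps` turns the scheduler's usual noise-to-data ladder into a forward (data-to-noise) schedule, and the `+= 1e-6` keeps the first sigma strictly positive. Below is a minimal standalone sketch of that construction in plain numpy/torch; the trailing 0.0 mirrors the zero sigma the flow-match scheduler normally appends, and no mu/shift is applied, so this only approximates what the hunk above does to the real scheduler.

import numpy as np
import torch

num_inference_steps = 8

# Flux-style flow-match ladder 1.0 -> 1/N, plus the trailing 0.0 that
# set_timesteps normally appends (assumption: no mu/shift applied here).
sigmas = np.append(np.linspace(1.0, 1 / num_inference_steps, num_inference_steps), 0.0)
sigmas = torch.from_numpy(sigmas).float()

# Flip so the schedule runs forward in time (~0 -> 1), then nudge the first
# entry off exact zero, as in the `+= 1e-6` line above.
sigmas = sigmas.flip(0)
sigmas[0] += 1e-6

# Step sizes an explicit Euler integrator would take over this ladder.
deltas = sigmas[1:] - sigmas[:-1]
print(sigmas)
print(deltas)
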
@@ -828,22 +824,18 @@ def __call__(
         y1 = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
+            for i in range(self._num_timesteps - 1):
                 # starting time s ∈ [0, 1] is defined as the time at which our controlled reverse ODE (15) is initialized.
                 # The initial state X_s = y_{1-s} is obtained by integrating the controlled forward ODE (8) from 0 → 1 − s.
                 if i > self._num_timesteps - starting_time:
                     continue
                 if self.interrupt:
                     continue
-                if even_timesteps is None:
-                    timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                    if divide_timestep:
-                        timestep = timestep / 1000
-                else:
-                    timestep = torch.tensor([even_timesteps[i]], device=latents.device, dtype=latents.dtype)
+                timestep = torch.tensor([self.scheduler.sigmas[i]], device=latents.device, dtype=latents.dtype)
+
                 noise_pred = self.transformer(
                     hidden_states=latents,
-                    timestep=timestep if divide_timestep else timestep / 1000,
+                    timestep=timestep,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
@@ -856,28 +848,20 @@ def __call__(
                 # Unconditional vector field: $u_{t_i}(Y_{t_i}) = u(Y_{t_i}, t_i, \Phi(\text{""}); \phi)$
                 unconditional_vector_field = noise_pred
                 # Conditional vector field: $u_{t_i}(Y_{t_i} | y_1) = \frac{y_1 - Y_{t_i}}{1 - t_i}$
-                conditional_vector_field = (y1 - latents) / (1 - timestep)
+                t_i = i / self._num_timesteps
+                conditional_vector_field = (y1 - latents) / (1 - t_i)
                 # Controlled vector field: $\hat{u}_{t_i}(Y_{t_i}) = u_{t_i}(Y_{t_i}) + \gamma \left( u_{t_i}(Y_{t_i} | y_1) - u_{t_i}(Y_{t_i}) \right)$
                 controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)
 
                 latents_dtype = latents.dtype
                 # Next state: $Y_{t_{i+1}} = Y_{t_i} + \hat{u}_{t_i}(Y_{t_i}) \cdot (\sigma(t_{i+1}) - \sigma(t_i))$
-                latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
+                latents = latents + controlled_vector_field * (self.scheduler.sigmas[i] - self.scheduler.sigmas[i + 1])
 
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)
 
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
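
Putting the loop body together: the transformer's prediction is treated as an unconditional velocity, blended with the analytic conditional velocity toward y1, and the result is integrated with an explicit Euler step over the flipped sigma ladder. The sketch below is a tensor-level restatement of that update under the conventions stated in the comments above; `model_velocity` stands in for the transformer output and `gamma` plays the role of `controller_guidance` (both names are illustrative, not from the pipeline).

import torch

def controlled_euler_step(latents, model_velocity, y1, t_i, sigma_i, sigma_next, gamma):
    # Conditional vector field toward y1: (y1 - Y_{t_i}) / (1 - t_i)
    conditional = (y1 - latents) / (1.0 - t_i)
    # Controlled field: u + gamma * (u_cond - u)
    controlled = model_velocity + gamma * (conditional - model_velocity)
    # Explicit Euler step over the sigma ladder, with the same sign
    # convention as the hunk above (sigma_i - sigma_next).
    return latents + controlled * (sigma_i - sigma_next)

# Toy usage with random tensors standing in for latents, noise target, and model output.
shape = (1, 4096, 64)
latents = torch.randn(shape)
y1 = torch.randn(shape)
velocity = torch.randn(shape)
latents = controlled_euler_step(latents, velocity, y1, t_i=0.0, sigma_i=1e-6, sigma_next=0.125, gamma=0.5)
print(latents.shape)
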