@@ -627,10 +627,6 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        sigmas=None,
-        flip_schedule=False,
-        even_timesteps=None,
-        divide_timestep=True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -767,6 +763,7 @@ def __call__(
         )

         # 4. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
         mu = calculate_shift(
             image_seq_len,
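
For reference, `calculate_shift` interpolates the flow-matching shift parameter mu from the image token count, and the scheduler then bends the linear sigmas through an exponential time shift. A minimal standalone sketch of that remapping, assuming the defaults used by the Flux pipelines in diffusers (base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16) and illustrative image dimensions:

import math

import numpy as np


def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.16):
    # mu grows linearly with the number of image tokens.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b


def time_shift(mu, sigmas):
    # Exponential remap applied by the scheduler under dynamic shifting:
    # sigma -> e^mu / (e^mu + (1 / sigma - 1)).
    return math.exp(mu) / (math.exp(mu) + (1.0 / sigmas - 1.0))


num_inference_steps = 28
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(64 * 64)  # e.g. a 1024x1024 image at 16x compression
print(time_shift(mu, sigmas))  # still descends from 1.0, but lingers at high noise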
@@ -783,9 +780,9 @@ def __call__(
             sigmas,
             mu=mu,
         )
-        if flip_schedule:
-            self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
-            self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
+        self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
+        self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
+        self.scheduler.sigmas[0] += 1e-6
         print(f"self.scheduler.sigmas {self.scheduler.sigmas}")
         print(f"self.scheduler.timesteps {self.scheduler.timesteps}")
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
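
Flipping sigmas and timesteps turns the usual denoising schedule into an ascending noising schedule for inversion. The scheduler appends a terminal sigma of 0, so after the flip the schedule would start at exactly 0; the 1e-6 nudge keeps that first sigma (and hence the first model timestep) strictly positive. A toy illustration with made-up values:

import torch

# Hypothetical schedule after set_timesteps: N sigmas plus the appended terminal 0.
sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])
timesteps = sigmas[:-1] * 1000           # tensor([1000., 750., 500., 250.])

# Flip for inversion: integrate from the clean image toward noise.
sigmas = sigmas.flip(0)                  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
timesteps = timesteps.flip(0)            # tensor([ 250.,  500.,  750., 1000.])
sigmas[0] += 1e-6                        # avoid starting the integration at exactly sigma = 0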
@@ -840,21 +837,14 @@ def __call__(

         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
+            for i in range(self._num_timesteps - 1):
                 if self.interrupt:
                     continue
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                if even_timesteps is None:
-                    timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                    if divide_timestep:
-                        timestep = timestep / 1000
-                else:
-                    timestep = torch.tensor([even_timesteps[i]], device=latents.device, dtype=latents.dtype)
+                timestep = torch.tensor([self.scheduler.sigmas[i]], device=latents.device, dtype=latents.dtype)
                 # Unconditional vector field: $v_{t_i}(X_{t_i}) = -u(X_{t_i}, 1 - t_i, \Phi(\text{prompt}); \phi)$
-                timestep = 1 - timestep
                 unconditional_vector_field = -self.transformer(
                     hidden_states=latents,
-                    timestep=timestep if divide_timestep else timestep / 1000,
+                    timestep=timestep,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
@@ -865,31 +855,24 @@ def __call__(
                 )[0]

                 # consider a time-varying controller guidance schedule η_t = η ∀ t ≤ τ and 0 otherwise
-                control_guidance = controller_guidance if i < stopping_time else 0.0
                 # Conditional vector field: $v_{t_i}(X_{t_i} | y_0) = \frac{y_0 - X_{t_i}}{1 - t_i}$
-                conditional_vector_field = (reference_image - latents) / timestep
+                t_i = i / self._num_timesteps
+                conditional_vector_field = (reference_image - latents) / (1 - t_i)
                 # Controlled vector field: $\hat{v}_{t_i}(X_{t_i}) = v_{t_i}(X_{t_i}) + \eta \left( v_{t_i}(X_{t_i} | y_0) - v_{t_i}(X_{t_i}) \right)$
-                controlled_vector_field = unconditional_vector_field + control_guidance * (conditional_vector_field - unconditional_vector_field)
+                controlled_vector_field = unconditional_vector_field
+                if i < stopping_time:
+                    controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 # Next state: $X_{t_{i+1}} = X_{t_i} + \hat{v}_{t_i}(X_{t_i}) \cdot (\sigma(t_{i+1}) - \sigma(t_i))$
-                latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
+                latents = latents + controlled_vector_field * (self.scheduler.sigmas[i + 1] - self.scheduler.sigmas[i])

                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)

-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
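
Stripped of pipeline plumbing, the loop is a fixed-step Euler integrator for the controlled vector field spelled out in the comments: the model's unconditional field is blended with the analytic conditional field (reference_image - latents) / (1 - t_i), and the guidance η is active only for steps i < τ. A self-contained sketch, with velocity_model standing in for the negated transformer call and all names hypothetical:

import torch


def integrate_controlled_field(velocity_model, y0, sigmas, eta=0.9, tau=None):
    # y0 plays the role of reference_image; sigmas is the flipped (ascending)
    # schedule including its final entry; eta is the controller guidance and
    # tau the step index where guidance switches off.
    num_steps = len(sigmas) - 1
    tau = num_steps if tau is None else tau
    x = y0.clone()
    for i in range(num_steps):
        t_i = i / num_steps
        v = velocity_model(x, sigmas[i])              # unconditional field
        v_cond = (y0 - x) / (1 - t_i)                 # field pointing back toward y0
        v_hat = v + eta * (v_cond - v) if i < tau else v
        x = x + v_hat * (sigmas[i + 1] - sigmas[i])   # Euler step over the sigma grid
    return x


# Smoke test with a zero velocity model and eta = 0: the state never moves.
x1 = integrate_controlled_field(lambda x, s: torch.zeros_like(x),
                                torch.randn(2, 16, 32, 32),
                                torch.linspace(1e-6, 1.0, 29), eta=0.0)

Writing the update directly over scheduler.sigmas, as the commit does, presumably sidesteps scheduler.step, which assumes the usual descending schedule and its own internal step index.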