@@ -627,6 +627,10 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        sigmas: Optional[List[float]] = None,
+        flip_schedule: bool = False,
+        even_timesteps: Optional[List[float]] = None,
+        divide_timestep: bool = True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
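For context, a minimal usage sketch of the new keyword arguments; the pipeline class, checkpoint name, and input image are assumptions for illustration, not taken from this diff:

import torch
from diffusers import FluxImg2ImgPipeline
from diffusers.utils import load_image

# Assumed setup: a Flux img2img checkpoint running through this patched __call__.
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
init_image = load_image("input.png")

image = pipe(
    prompt="a photo of a cat",
    image=init_image,
    strength=0.8,
    num_inference_steps=28,
    sigmas=None,           # fall back to the scheduler's default sigma spacing
    flip_schedule=True,    # reverse the sigma/timestep schedule (see below)
    even_timesteps=None,   # or an explicit per-step list, e.g. [0.9, 0.7, ...]
    divide_timestep=True,  # rescale integer timesteps into [0, 1]
).images[0]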
@@ -763,8 +767,6 @@ def __call__(
         )
 
         # 4. Prepare timesteps
-        # Flux noise scheduler $\sigma : [0, 1] \to \mathbb{R}$
-        sigmas = np.linspace(0.0, 1.0, num_inference_steps)
         image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
         mu = calculate_shift(
             image_seq_len,
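A quick worked example of the image_seq_len computation above; the resolution and vae_scale_factor values are illustrative assumptions, not values from this diff:

height = width = 1024
vae_scale_factor = 16  # assumed scale factor for this sketch
image_seq_len = (int(height) // vae_scale_factor) * (int(width) // vae_scale_factor)
print(image_seq_len)   # 64 * 64 = 4096 latent tokens, fed into calculate_shift to derive mu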
@@ -781,6 +783,11 @@ def __call__(
             sigmas,
             mu=mu,
         )
+        if flip_schedule:
+            self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
+            self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
+            print(f"self.scheduler.sigmas {self.scheduler.sigmas}")
+            print(f"self.scheduler.timesteps {self.scheduler.timesteps}")
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
 
         if num_inference_steps < 1:
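A standalone sketch of what flip_schedule does to the scheduler state; the sigma values below are illustrative, not the scheduler's actual output:

import torch

sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])  # usual order: noise -> data
timesteps = sigmas * 1000

sigmas = sigmas.flip(0)        # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
timesteps = timesteps.flip(0)  # tensor([   0.,  250.,  500.,  750., 1000.])
# The denoising loop now walks the schedule from data toward noise
# instead of from noise toward data.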
@@ -837,13 +844,17 @@ def __call__(
                 if self.interrupt:
                     continue
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                timestep = timestep / 1000
+                if even_timesteps is None:
+                    timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                    if divide_timestep:
+                        timestep = timestep / 1000
+                else:
+                    timestep = torch.tensor([even_timesteps[i]], device=latents.device, dtype=latents.dtype)
                 # Unconditional vector field: $v_{t_i}(X_{t_i}) = -u(X_{t_i}, 1 - t_i, \Phi(\text{prompt}); \phi)$
                 timestep = 1 - timestep
                 unconditional_vector_field = -self.transformer(
                     hidden_states=latents,
-                    timestep=timestep,
+                    timestep=timestep if divide_timestep else timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
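To make the per-step time handling concrete, a toy trace of the even_timesteps=None, divide_timestep=True branch; the raw timestep value, batch size, and the final model call are stand-ins for this sketch:

import torch

t = torch.tensor(750.0)          # raw scheduler timestep on the [0, 1000] scale
batch_size = 2

timestep = t.expand(batch_size)  # broadcast to the batch dimension
timestep = timestep / 1000       # divide_timestep=True: map into [0, 1] -> 0.75
timestep = 1 - timestep          # reflected time 1 - t_i from the comment's equation
                                 # v_{t_i}(X_{t_i}) = -u(X_{t_i}, 1 - t_i, ...)
print(timestep)                  # tensor([0.2500, 0.2500])
# The transformer's prediction at this reflected time is then negated
# to form the unconditional vector field.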