@@ -216,6 +216,7 @@ def __init__(
         rescale_betas_zero_snr: bool = False,
         use_dynamic_shifting: bool = False,
         time_shift_type: str = "exponential",
+        shift_terminal: Optional[float] = None,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -235,6 +236,8 @@ def __init__(
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+        if shift_terminal is not None and not use_flow_sigmas:
+            raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")
 
         if rescale_betas_zero_snr:
             self.betas = rescale_zero_terminal_snr(self.betas)
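The new `shift_terminal` option only makes sense on the flow-matching sigma schedule, so the constructor now rejects it when `use_flow_sigmas` is off. A minimal sketch of the intended behavior, assuming this diff targets diffusers' `UniPCMultistepScheduler` (the `multistep_uni_p_bh_update` reference further down suggests so; adjust the class name otherwise):

```python
from diffusers import UniPCMultistepScheduler

# Flow sigmas enabled: terminal shifting is allowed.
scheduler = UniPCMultistepScheduler(use_flow_sigmas=True, flow_shift=3.0, shift_terminal=0.1)

# Flow sigmas disabled (the default): the new check raises instead of silently
# ignoring `shift_terminal`.
try:
    UniPCMultistepScheduler(shift_terminal=0.1)
except ValueError as err:
    print(err)  # `shift_terminal` is only supported when `use_flow_sigmas=True`.
```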
@@ -303,7 +306,12 @@ def set_begin_index(self, begin_index: int = 0):
         self._begin_index = begin_index
 
     def set_timesteps(
-        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
+        sigmas: Optional[List[float]] = None,
+        timesteps: Optional[List[float]] = None,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -314,10 +322,23 @@ def set_timesteps(
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        if self.config.use_dynamic_shifting and mu is None:
+            raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
+
+        if sigmas is not None or timesteps is not None:
+            if not self.config.use_flow_sigmas:
+                raise ValueError(
+                    "Passing `sigmas` or `timesteps` is only supported when `use_flow_sigmas=True`. "
+                    "Please set `use_flow_sigmas=True` during scheduler initialization."
+                )
+            num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
+        if sigmas is not None and timesteps is not None:
+            if len(sigmas) != len(timesteps):
+                raise ValueError("`sigmas` and `timesteps` should have the same length")
+
+        is_timesteps_provided = timesteps is not None
+
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
-        if mu is not None:
-            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
-            self.config.flow_shift = np.exp(mu)
         if self.config.timestep_spacing == "linspace":
             timesteps = (
                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
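With the widened signature, `num_inference_steps` becomes optional: when custom `sigmas` (or `timesteps`) are supplied it is inferred from their length, and both are only accepted on the flow-sigma schedule. A rough usage sketch with illustrative values only, reusing the hypothetical `UniPCMultistepScheduler` setup from above; the annotation declares `List[float]`, but a NumPy array is passed here so the flow-shift arithmetic below applies elementwise:

```python
import numpy as np
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler(use_flow_sigmas=True, flow_shift=3.0)

# Classic call: an explicit step count.
scheduler.set_timesteps(num_inference_steps=30)

# New call: custom sigmas in (0, 1], decreasing, mirroring the default flow schedule.
# `num_inference_steps` is inferred as len(sigmas).
custom_sigmas = np.linspace(1.0, 1.0 / 1000, 31)[:-1]
scheduler.set_timesteps(sigmas=custom_sigmas)
```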
@@ -342,7 +363,8 @@ def set_timesteps(
                 f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
             )
 
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        if sigmas is None:
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         if self.config.use_karras_sigmas:
             log_sigmas = np.log(sigmas)
             sigmas = np.flip(sigmas).copy()
@@ -386,10 +408,21 @@ def set_timesteps(
                 )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         elif self.config.use_flow_sigmas:
-            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
-            sigmas = 1.0 - alphas
-            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
-            timesteps = (sigmas * self.config.num_train_timesteps).copy()
+            if sigmas is None:
+                sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
+            if self.config.use_dynamic_shifting:
+                sigmas = self.time_shift(mu, 1.0, sigmas)
+            else:
+                sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
+            if self.config.shift_terminal:
+                sigmas = self.stretch_shift_to_terminal(sigmas)
+            eps = 1e-6
+            if np.fabs(sigmas[0] - 1) < eps:
+                sigmas[0] -= (
+                    eps  # to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
+                )
+            if not is_timesteps_provided:
+                timesteps = (sigmas * self.config.num_train_timesteps).copy()
             if self.config.final_sigmas_type == "sigma_min":
                 sigma_last = sigmas[-1]
             elif self.config.final_sigmas_type == "zero":
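The flow-sigma branch now applies up to three transforms in sequence: a dynamic shift driven by `mu` when `use_dynamic_shifting` is on, the static `flow_shift` otherwise, and an optional terminal stretch. A standalone NumPy sketch of the first two, using the same formulas that appear in this diff and purely illustrative values (`flow_shift=3.0`, 8 steps):

```python
import numpy as np

num_train_timesteps = 1000
num_inference_steps = 8
sigmas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1)[:-1]

# Static shift: sigma' = s * sigma / (1 + (s - 1) * sigma), with s = flow_shift.
flow_shift = 3.0
static = flow_shift * sigmas / (1 + (flow_shift - 1) * sigmas)

# Dynamic "exponential" shift: sigma' = exp(mu) / (exp(mu) + (1 / sigma - 1) ** 1.0),
# which is algebraically the static formula with s = exp(mu).
mu = 1.0986  # exp(mu) ~= 3.0, so this roughly matches the static shift above
dynamic = np.exp(mu) / (np.exp(mu) + (1 / sigmas - 1) ** 1.0)

print(np.allclose(static, dynamic, atol=1e-3))  # True: exp(mu) plays the role of flow_shift
```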
@@ -429,6 +462,43 @@ def set_timesteps(
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
+    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+        if self.config.time_shift_type == "exponential":
+            return self._time_shift_exponential(mu, sigma, t)
+        elif self.config.time_shift_type == "linear":
+            return self._time_shift_linear(mu, sigma, t)
+
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
+    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+        r"""
+        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+        value.
+
+        Reference:
+        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+        Args:
+            t (`torch.Tensor`):
+                A tensor of timesteps to be stretched and shifted.
+
+        Returns:
+            `torch.Tensor`:
+                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+        """
+        one_minus_z = 1 - t
+        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+        stretched_t = 1 - (one_minus_z / scale_factor)
+        return stretched_t
+
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
+    def _time_shift_exponential(self, mu, sigma, t):
+        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
+    def _time_shift_linear(self, mu, sigma, t):
+        return mu / (mu + (1 / t - 1) ** sigma)
+
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
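The terminal stretch copied in above rescales the `1 - sigma` axis so the schedule's last value lands on `shift_terminal` instead of a different floor. A small standalone check of that formula, with made-up values and plain NumPy standing in for the tensor input:

```python
import numpy as np

shift_terminal = 0.1
sigmas = np.array([0.95, 0.7, 0.4, 0.15])  # a made-up decreasing flow schedule

one_minus_z = 1 - sigmas
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
stretched = 1 - one_minus_z / scale_factor

print(stretched[-1])  # ~0.1: the schedule now terminates at shift_terminal (up to float rounding)
print(stretched[0])   # the earlier sigmas shift proportionally (~0.947 here)
```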