1414
1515import  math 
1616from  dataclasses  import  dataclass 
17- from  typing  import  Optional , Tuple , Union 
17+ from  typing  import  List ,  Optional , Tuple , Union 
1818
1919import  torch 
2020
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
7777            Video](https://imagen.research.google/video/paper.pdf) paper). 
7878        rho (`float`, *optional*, defaults to 7.0): 
7979            The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1]. 
80+         final_sigmas_type (`str`, defaults to `"zero"`):
81+             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
82+             sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
8083    """ 
8184
8285    _compatibles  =  []
@@ -92,22 +95,32 @@ def __init__(
9295        num_train_timesteps : int  =  1000 ,
9396        prediction_type : str  =  "epsilon" ,
9497        rho : float  =  7.0 ,
98+         final_sigmas_type : str  =  "zero" ,  # can be "zero" or "sigma_min" 
9599    ):
96100        if  sigma_schedule  not  in "karras" , "exponential" ]:
97101            raise  ValueError (f"Wrong value for provided for `{ sigma_schedule = }  )
98102
99103        # setable values 
100104        self .num_inference_steps  =  None 
101105
102-         ramp  =  torch .linspace ( 0 ,  1 ,  num_train_timesteps ) 
106+         sigmas  =  torch .arange ( num_train_timesteps   +   1 )  /   num_train_timesteps 
103107        if  sigma_schedule  ==  "karras" :
104-             sigmas  =  self ._compute_karras_sigmas (ramp )
108+             sigmas  =  self ._compute_karras_sigmas (sigmas )
105109        elif  sigma_schedule  ==  "exponential" :
106-             sigmas  =  self ._compute_exponential_sigmas (ramp )
110+             sigmas  =  self ._compute_exponential_sigmas (sigmas )
107111
108112        self .timesteps  =  self .precondition_noise (sigmas )
109113
110-         self .sigmas  =  torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
114+         if  self .config .final_sigmas_type  ==  "sigma_min" :
115+             sigma_last  =  sigmas [- 1 ]
116+         elif  self .config .final_sigmas_type  ==  "zero" :
117+             sigma_last  =  0 
118+         else :
119+             raise  ValueError (
120+                 f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type }  
121+             )
122+ 
123+         self .sigmas  =  torch .cat ([sigmas , torch .full ((1 ,), fill_value = sigma_last , device = sigmas .device )])
111124
112125        self .is_scale_input_called  =  False 
113126
@@ -197,7 +210,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
197210        self .is_scale_input_called  =  True 
198211        return  sample 
199212
200-     def  set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] =  None ):
213+     def  set_timesteps (
214+         self ,
215+         num_inference_steps : int  =  None ,
216+         device : Union [str , torch .device ] =  None ,
217+         sigmas : Optional [Union [torch .Tensor , List [float ]]] =  None ,
218+     ):
201219        """ 
202220        Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
203221
@@ -206,19 +224,36 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
206224                The number of diffusion steps used when generating samples with a pre-trained model. 
207225            device (`str` or `torch.device`, *optional*): 
208226                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
227+             sigmas (`Union[torch.Tensor, List[float]]`, *optional*): 
228+                 Custom sigmas to use for the denoising process. If not defined, the default behavior when 
229+                 `num_inference_steps` is passed will be used. 
209230        """ 
210231        self .num_inference_steps  =  num_inference_steps 
211232
212-         ramp  =  torch .linspace (0 , 1 , self .num_inference_steps )
233+         if  sigmas  is  None :
234+             sigmas  =  torch .linspace (0 , 1 , self .num_inference_steps )
235+         elif  isinstance (sigmas , float ):
236+             sigmas  =  torch .tensor (sigmas , dtype = torch .float32 )
237+         else :
238+             sigmas  =  sigmas 
213239        if  self .config .sigma_schedule  ==  "karras" :
214-             sigmas  =  self ._compute_karras_sigmas (ramp )
240+             sigmas  =  self ._compute_karras_sigmas (sigmas )
215241        elif  self .config .sigma_schedule  ==  "exponential" :
216-             sigmas  =  self ._compute_exponential_sigmas (ramp )
242+             sigmas  =  self ._compute_exponential_sigmas (sigmas )
217243
218244        sigmas  =  sigmas .to (dtype = torch .float32 , device = device )
219245        self .timesteps  =  self .precondition_noise (sigmas )
220246
221-         self .sigmas  =  torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
247+         if  self .config .final_sigmas_type  ==  "sigma_min" :
248+             sigma_last  =  sigmas [- 1 ]
249+         elif  self .config .final_sigmas_type  ==  "zero" :
250+             sigma_last  =  0 
251+         else :
252+             raise  ValueError (
253+                 f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type }  
254+             )
255+ 
256+         self .sigmas  =  torch .cat ([sigmas , torch .full ((1 ,), fill_value = sigma_last , device = sigmas .device )])
222257        self ._step_index  =  None 
223258        self ._begin_index  =  None 
224259        self .sigmas  =  self .sigmas .to ("cpu" )  # to avoid too much CPU/GPU communication 
0 commit comments