1414
1515import math
1616from dataclasses import dataclass
17- from typing import Optional , Tuple , Union
17+ from typing import List , Optional , Tuple , Union
1818
1919import torch
2020
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
7777 Video](https://imagen.research.google/video/paper.pdf) paper).
7878 rho (`float`, *optional*, defaults to 7.0):
7979 The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
80+ final_sigmas_type (`str`, defaults to `"zero"`):
81+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
82+ sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
8083 """
8184
8285 _compatibles = []
@@ -92,22 +95,32 @@ def __init__(
9295 num_train_timesteps : int = 1000 ,
9396 prediction_type : str = "epsilon" ,
9497 rho : float = 7.0 ,
98+ final_sigmas_type : str = "zero" , # can be "zero" or "sigma_min"
9599 ):
96100 if sigma_schedule not in ["karras" , "exponential" ]:
97101 raise ValueError (f"Wrong value for provided for `{ sigma_schedule = } `.`" )
98102
99103 # setable values
100104 self .num_inference_steps = None
101105
102- ramp = torch .linspace ( 0 , 1 , num_train_timesteps )
106+ sigmas = torch .arange ( num_train_timesteps + 1 ) / num_train_timesteps
103107 if sigma_schedule == "karras" :
104- sigmas = self ._compute_karras_sigmas (ramp )
108+ sigmas = self ._compute_karras_sigmas (sigmas )
105109 elif sigma_schedule == "exponential" :
106- sigmas = self ._compute_exponential_sigmas (ramp )
110+ sigmas = self ._compute_exponential_sigmas (sigmas )
107111
108112 self .timesteps = self .precondition_noise (sigmas )
109113
110- self .sigmas = torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
114+ if self .config .final_sigmas_type == "sigma_min" :
115+ sigma_last = sigmas [- 1 ]
116+ elif self .config .final_sigmas_type == "zero" :
117+ sigma_last = 0
118+ else :
119+ raise ValueError (
120+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
121+ )
122+
123+ self .sigmas = torch .cat ([sigmas , torch .full ((1 ,), fill_value = sigma_last , device = sigmas .device )])
111124
112125 self .is_scale_input_called = False
113126
@@ -197,7 +210,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
197210 self .is_scale_input_called = True
198211 return sample
199212
200- def set_timesteps (self , num_inference_steps : int , device : Union [str , torch .device ] = None ):
213+ def set_timesteps (
214+ self ,
215+ num_inference_steps : int = None ,
216+ device : Union [str , torch .device ] = None ,
217+ sigmas : Optional [Union [torch .Tensor , List [float ]]] = None ,
218+ ):
201219 """
202220 Sets the discrete timesteps used for the diffusion chain (to be run before inference).
203221
@@ -206,19 +224,36 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
206224 The number of diffusion steps used when generating samples with a pre-trained model.
207225 device (`str` or `torch.device`, *optional*):
208226 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
227+ sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
228+ Custom sigmas to use for the denoising process. If `None`, the schedule is instead built from
229+ `num_inference_steps` evenly spaced values.
209230 """
210231 self .num_inference_steps = num_inference_steps
211232
212- ramp = torch .linspace (0 , 1 , self .num_inference_steps )
233+ if sigmas is None :
234+ sigmas = torch .linspace (0 , 1 , self .num_inference_steps )
235+ elif isinstance (sigmas , float ):
236+ sigmas = torch .tensor (sigmas , dtype = torch .float32 )
237+ else :
238+ sigmas = sigmas
213239 if self .config .sigma_schedule == "karras" :
214- sigmas = self ._compute_karras_sigmas (ramp )
240+ sigmas = self ._compute_karras_sigmas (sigmas )
215241 elif self .config .sigma_schedule == "exponential" :
216- sigmas = self ._compute_exponential_sigmas (ramp )
242+ sigmas = self ._compute_exponential_sigmas (sigmas )
217243
218244 sigmas = sigmas .to (dtype = torch .float32 , device = device )
219245 self .timesteps = self .precondition_noise (sigmas )
220246
221- self .sigmas = torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
247+ if self .config .final_sigmas_type == "sigma_min" :
248+ sigma_last = sigmas [- 1 ]
249+ elif self .config .final_sigmas_type == "zero" :
250+ sigma_last = 0
251+ else :
252+ raise ValueError (
253+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got { self .config .final_sigmas_type } "
254+ )
255+
256+ self .sigmas = torch .cat ([sigmas , torch .full ((1 ,), fill_value = sigma_last , device = sigmas .device )])
222257 self ._step_index = None
223258 self ._begin_index = None
224259 self .sigmas = self .sigmas .to ("cpu" ) # to avoid too much CPU/GPU communication
0 commit comments