@@ -54,7 +54,7 @@ def scale_base_distribution(self):
5454 return 1.0
5555 elif self .variance_type == "exploding" :
5656 # e.g., EDM is a variance exploding schedule
57- return ops .exp (- self ._log_snr_min )
57+ return ops .sqrt ( ops . exp (- self ._log_snr_min ) )
5858 else :
5959 raise ValueError (f"Unknown variance type: { self .variance_type } " )
6060
@@ -279,17 +279,20 @@ class EDMNoiseSchedule(NoiseSchedule):
279279 def __init__ (self , sigma_data : float = 0.5 , sigma_min : float = 0.002 , sigma_max : float = 80 ):
280280 super ().__init__ (name = "edm_noise_schedule" , variance_type = "exploding" )
281281 self .sigma_data = sigma_data
282- self .sigma_max = sigma_max
283- self .sigma_min = sigma_min
282+ # training settings
284283 self .p_mean = - 1.2
285284 self .p_std = 1.2
285+ # sampling settings
286+ self .sigma_max = sigma_max
287+ self .sigma_min = sigma_min
286288 self .rho = 7
287289
288290 # convert EDM parameters to signal-to-noise ratio formulation
289291 self ._log_snr_min = - 2 * ops .log (sigma_max )
290292 self ._log_snr_max = - 2 * ops .log (sigma_min )
291- self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
292- self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
293+ # t is not truncated for EDM by definition of the sampling schedule
294+ self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = False )
295+ self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = False )
293296
294297 def get_log_snr (self , t : Union [float , Tensor ], training : bool ) -> Tensor :
295298 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
0 commit comments