@@ -276,7 +276,7 @@ class EDMNoiseSchedule(NoiseSchedule):
     [1] Elucidating the Design Space of Diffusion-Based Generative Models: Karras et al. (2022)
     """
 
-    def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80):
+    def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0):
         super().__init__(name="edm_noise_schedule", variance_type="exploding")
         self.sigma_data = sigma_data
         # training settings
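For orientation, a hypothetical construction with the defaults above; the import path below is a placeholder chosen for illustration, not an API confirmed by this diff:

```python
# Hypothetical usage sketch; the module path is a placeholder, not a confirmed API.
from edm_noise_schedule import EDMNoiseSchedule  # placeholder import

# Defaults follow Karras et al. (2022): sigma_data = 0.5, sigma in [0.002, 80.0].
schedule = EDMNoiseSchedule(sigma_data=0.5, sigma_min=0.002, sigma_max=80.0)
log_snr_mid = schedule.get_log_snr(t=0.5, training=False)
```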
@@ -291,26 +291,25 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
         self._log_snr_min = -2 * ops.log(sigma_max)
         self._log_snr_max = -2 * ops.log(sigma_min)
         # t is not truncated for EDM by definition of the sampling schedule
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=False)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=False)
+        # training bounds are not critical, but should be set to avoid numerical issues
+        self._log_snr_min_training = self._log_snr_min * 2  # t = 1 is almost surely never sampled during training
+        self._log_snr_max_training = self._log_snr_max * 2  # t = 0 is almost surely never sampled during training
 
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
-        t_trunc = self._t_min + (self._t_max - self._t_min) * t
         if training:
             # SNR = -dist.icdf(t)
             loc = -2 * self.p_mean
             scale = 2 * self.p_std
-            x = t_trunc
-            snr = -(loc + scale * ops.erfinv(2 * x - 1) * math.sqrt(2))
-            snr = keras.ops.clip(snr, x_min=self._log_snr_min, x_max=self._log_snr_max)
+            snr = -(loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2))
+            snr = keras.ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
         else:  # sampling
             snr = (
                 -2
                 * self.rho
                 * ops.log(
                     self.sigma_max ** (1 / self.rho)
-                    + (1 - t_trunc) * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
+                    + (1 - t) * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
                 )
             )
         return snr
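A minimal standalone sanity check of both branches of get_log_snr, with stand-in constants (p_mean = -1.2, p_std = 1.2, rho = 7 are the Karras et al. (2022) defaults; the diff itself does not show them). All names are local to this sketch, not the library API:

```python
import math

# Stand-in constants: the Karras et al. (2022) defaults, assumed here since
# the diff does not show them.
p_mean, p_std = -1.2, 1.2
sigma_min, sigma_max, rho = 0.002, 80.0, 7.0
loc, scale = -2.0 * p_mean, 2.0 * p_std

def normal_cdf(x: float, loc: float, scale: float) -> float:
    return 0.5 * (1.0 + math.erf((x - loc) / (scale * math.sqrt(2.0))))

# Training branch: -(loc + scale * erfinv(2t - 1) * sqrt(2)) is the negated
# quantile of N(loc, scale). At t = 0.5, erfinv(0) = 0, so snr = -loc = 2 * p_mean,
# and pushing -snr back through the normal CDF recovers t = 0.5.
snr_mid = 2.0 * p_mean
assert math.isclose(normal_cdf(-snr_mid, loc, scale), 0.5)

# Sampling branch: t = 0 should give log_snr_max = -2*log(sigma_min),
# and t = 1 should give log_snr_min = -2*log(sigma_max).
def log_snr_sampling(t: float) -> float:
    s_min, s_max = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return -2.0 * rho * math.log(s_max + (1.0 - t) * (s_min - s_max))

assert math.isclose(log_snr_sampling(0.0), -2.0 * math.log(sigma_min))
assert math.isclose(log_snr_sampling(1.0), -2.0 * math.log(sigma_max))
```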
@@ -338,20 +337,18 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
             raise NotImplementedError("Derivative of log SNR is not implemented for training mode.")
         # sampling mode
         t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)
-        t_trunc = self._t_min + (self._t_max - self._t_min) * t
 
         # SNR = -2*rho*log(s_max + (1 - x)*(s_min - s_max))
         s_max = self.sigma_max ** (1 / self.rho)
         s_min = self.sigma_min ** (1 / self.rho)
-        u = s_max + (1 - t_trunc) * (s_min - s_max)
+        u = s_max + (1 - t) * (s_min - s_max)
         # d/dx snr = 2*rho*(s_min - s_max) / u
         dsnr_dx = 2 * self.rho * (s_min - s_max) / u
 
         # Using the chain rule on f(t) = log(1 + e^(-snr(t))):
         # f'(t) = -(e^(-snr(t)) / (1 + e^(-snr(t)))) * dsnr_dt
-        dsnr_dt = dsnr_dx * (self._t_max - self._t_min)
         factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
-        return -factor * dsnr_dt
+        return -factor * dsnr_dx
 
     def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
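As a cross-check of the closed form above: derivative_log_snr returns d/dt log(1 + e^(-snr(t))) for the sampling schedule, so it should match a central finite difference of that function. A standalone sketch with the same stand-in constants as before, not the library API:

```python
import math

# Finite-difference check of the derivative: the closed form mirrors the diff,
# with dsnr/dt = 2*rho*(s_min - s_max)/u and factor = e^(-snr) / (1 + e^(-snr)).
# rho = 7 is the Karras et al. default, assumed here (not shown in the diff).
sigma_min, sigma_max, rho = 0.002, 80.0, 7.0
s_min, s_max = sigma_min ** (1 / rho), sigma_max ** (1 / rho)

def log_snr(t: float) -> float:
    return -2.0 * rho * math.log(s_max + (1.0 - t) * (s_min - s_max))

def d_log1p_exp_neg_snr(t: float) -> float:
    u = s_max + (1.0 - t) * (s_min - s_max)
    dsnr_dt = 2.0 * rho * (s_min - s_max) / u
    factor = math.exp(-log_snr(t)) / (1.0 + math.exp(-log_snr(t)))
    return -factor * dsnr_dt

# Central finite difference of f(t) = log(1 + e^(-snr(t))) at an interior point.
t, h = 0.7, 1e-6
fd = (math.log1p(math.exp(-log_snr(t + h))) - math.log1p(math.exp(-log_snr(t - h)))) / (2.0 * h)
assert math.isclose(d_log1p_exp_neg_snr(t), fd, rel_tol=1e-5)
```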