@@ -146,14 +146,15 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
146146 self ._log_snr_min = min_log_snr
147147 self ._log_snr_max = max_log_snr
148148
149- self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
150- self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
149+ self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
150+ self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
151151
152152 def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
153153 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
154154 t_trunc = self ._t_min + (self ._t_max - self ._t_min ) * t
155155 # log SNR = -log(exp(t^2) - 1)
156- return - ops .log (ops .exp (ops .square (t_trunc )) - 1 )
156+ # equivalent, but more stable for large t (avoids overflow in exp(t^2)): -t^2 - log(1 - exp(-t^2))
157+ return - ops .square (t_trunc ) - ops .log (1 - ops .exp (- ops .square (t_trunc )))
157158
158159 def get_t_from_log_snr (self , log_snr_t : Union [float , Tensor ], training : bool ) -> Tensor :
159160 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
@@ -205,8 +206,8 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_co
205206 self ._log_snr_max = max_log_snr
206207 self ._s_shift_cosine = s_shift_cosine
207208
208- self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
209- self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
209+ self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
210+ self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
210211
211212 def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
212213 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -266,8 +267,8 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
266267 # convert EDM parameters to signal-to-noise ratio formulation
267268 self ._log_snr_min = - 2 * ops .log (sigma_max )
268269 self ._log_snr_max = - 2 * ops .log (sigma_min )
269- self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
270- self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
270+ self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
271+ self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
271272
272273 def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
273274 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
0 commit comments