@@ -58,7 +58,7 @@ def scale_base_distribution(self):
5858 raise ValueError (f"Unknown variance type: { self .variance_type } " )
5959
6060 @abstractmethod
61- def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
61+ def get_log_snr (self , t : Union [ float , Tensor ] , training : bool ) -> Tensor :
6262 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
6363 pass
6464
@@ -145,10 +145,10 @@ def validate(self):
145145 raise ValueError ("t(0) must be finite." )
146146 if not ops .isfinite (self .get_t_from_log_snr (self ._log_snr_min , training = training )):
147147 raise ValueError ("t(1) must be finite." )
148- if not ops .isfinite (self .derivative_log_snr (self ._log_snr_max , training = training )):
149- raise ValueError ("dt/t log_snr(0) must be finite." )
150- if not ops .isfinite (self .derivative_log_snr (self ._log_snr_min , training = training )):
151- raise ValueError ("dt/t log_snr(1) must be finite." )
148+ if not ops .isfinite (self .derivative_log_snr (self ._log_snr_max , training = False )):
149+ raise ValueError ("dt/t log_snr(0) must be finite." )
150+ if not ops .isfinite (self .derivative_log_snr (self ._log_snr_min , training = False )):
151+ raise ValueError ("dt/t log_snr(1) must be finite." )
152152
153153
154154@serializable
@@ -168,7 +168,7 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
168168 self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
169169 self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
170170
171- def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
171+ def get_log_snr (self , t : Union [ float , Tensor ] , training : bool ) -> Tensor :
172172 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
173173 t_trunc = self ._t_min + (self ._t_max - self ._t_min ) * t
174174 # SNR = -log(exp(t^2) - 1)
@@ -228,7 +228,7 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_co
228228 self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
229229 self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
230230
231- def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
231+ def get_log_snr (self , t : Union [ float , Tensor ] , training : bool ) -> Tensor :
232232 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
233233 t_trunc = self ._t_min + (self ._t_max - self ._t_min ) * t
234234 # SNR = -2 * log(tan(pi*t/2))
@@ -289,7 +289,7 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
289289 self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
290290 self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_min , training = True )
291291
292- def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
292+ def get_log_snr (self , t : Union [ float , Tensor ] , training : bool ) -> Tensor :
293293 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
294294 t_trunc = self ._t_min + (self ._t_max - self ._t_min ) * t
295295 if training :
0 commit comments