@@ -69,8 +69,8 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
         r"""Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE."""
         pass
 
-    def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = True) -> tuple[Tensor, Tensor]:
-        r"""Compute the drift and optionally the diffusion term for the reverse SDE.
+    def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]:
+        r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.
         Usually it can be derived from the derivative of the schedule:
         \beta(t) = d/dt log(1 + e^(-snr(t)))
         f(z, t) = -0.5 * \beta(t) * z
@@ -84,14 +84,14 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
         # Default implementation is to return the diffusion term only
         beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
         if x is None:  # return g only
-            return ops.sqrt(beta)
+            return beta
         if self.variance_type == "preserving":
             f = -0.5 * beta * x
         elif self.variance_type == "exploding":
             f = ops.zeros_like(beta)
         else:
             raise ValueError(f"Unknown variance type: {self.variance_type}")
-        return f, ops.sqrt(beta)
+        return f, beta
 
     def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
         """Get alpha and sigma for a given log signal-to-noise ratio (lambda).
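The contract change in this hunk is easy to miss: `get_drift_diffusion` now returns \beta(t) = g(t)^2 instead of g(t) = sqrt(\beta(t)), so any caller that needs the diffusion coefficient itself (the noise term of the reverse SDE named in the docstring) has to take the square root on its side. A minimal Euler-Maruyama sketch under the new contract; `euler_maruyama_step`, `schedule`, and `score_fn` are hypothetical names, not part of this diff:

```python
import keras
from keras import ops

def euler_maruyama_step(schedule, score_fn, z, log_snr_t, dt):
    # one reverse-SDE step: d(z) = [f(z, t) - g(t)^2 * score(z, lambda)] dt + g(t) dW
    f, g_squared = schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=z)
    g = ops.sqrt(g_squared)  # the noise term still needs g itself
    drift = f - g_squared * score_fn(z, log_snr_t)
    dw = keras.random.normal(ops.shape(z)) * ops.sqrt(ops.abs(dt))  # dt < 0 in reverse time
    return z + drift * dt + g * dw
```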
@@ -144,7 +144,7 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
         self._log_snr_min = min_log_snr
         self._log_snr_max = max_log_snr
 
         self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
+        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
 
     def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
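The direction of this fix matters: these schedules map diffusion time to log-SNR monotonically decreasing (t near 0 is almost noise-free, so the log-SNR is maximal there). `_t_min` therefore correctly comes from `_log_snr_max`, and it is `_t_max` that must switch to `_log_snr_min`; the same correction recurs in the cosine and EDM schedules below. A quick sanity check sketch, assuming `schedule` is a constructed instance of this class:

```python
# high SNR corresponds to early time, so the inversion must satisfy this ordering
t_at_max_snr = schedule.get_t_from_log_snr(log_snr_t=schedule._log_snr_max, training=True)
t_at_min_snr = schedule.get_t_from_log_snr(log_snr_t=schedule._log_snr_min, training=True)
assert float(t_at_max_snr) < float(t_at_min_snr)  # i.e. _t_min < _t_max
```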
@@ -176,9 +176,9 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
         Default is the likelihood weighting based on Song et al. (2021).
         """
-        g = self.get_drift_diffusion(log_snr_t=log_snr_t)
+        g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
         sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
-        return ops.square(g / sigma_t)
+        return g_squared / ops.square(sigma_t)
 
     def get_config(self):
         return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max)
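These two changes must land together with the `get_drift_diffusion` change above: under the old return convention, `ops.square(g / sigma_t)` computed the likelihood weight g^2 / sigma^2, and feeding it g^2 instead would silently yield g^4 / sigma^2. A small standalone check that the rewritten expression matches the old one:

```python
import numpy as np

beta, sigma = 0.7, 0.3                  # stand-in values for beta(t) = g(t)^2 and sigma(t)
old = np.square(np.sqrt(beta) / sigma)  # previous code, fed g = sqrt(beta)
new = beta / np.square(sigma)           # new code, fed g_squared = beta
assert np.isclose(old, new)             # same likelihood weight g^2 / sigma^2
```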
@@ -203,7 +203,7 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_co
         self._log_snr_max = max_log_snr
         self._s_shift_cosine = s_shift_cosine
 
         self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
+        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
 
     def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
@@ -254,7 +254,6 @@ class EDMNoiseSchedule(NoiseSchedule):
 
     def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80):
         super().__init__(name="edm_noise_schedule", variance_type="exploding")
-        super().__init__(name="edm_noise_schedule")
         self.sigma_data = sigma_data
         self.sigma_max = sigma_max
         self.sigma_min = sigma_min
@@ -265,7 +264,7 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
         # convert EDM parameters to signal-to-noise ratio formulation
         self._log_snr_min = -2 * ops.log(sigma_max)
         self._log_snr_max = -2 * ops.log(sigma_min)
         self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
+        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
 
     def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
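Worth spelling out the conversion used here: EDM parameterizes noise by sigma, and with lambda = -2 * log(sigma) the largest sigma yields the smallest log-SNR, which is why `sigma_max` feeds `_log_snr_min`. A standalone check of that identity:

```python
import numpy as np

sigma_min, sigma_max = 0.002, 80.0
log_snr_min = -2 * np.log(sigma_max)  # most noise -> smallest log-SNR
log_snr_max = -2 * np.log(sigma_min)  # least noise -> largest log-SNR
assert log_snr_min < log_snr_max
assert np.isclose(np.exp(-log_snr_min / 2), sigma_max)  # invert: sigma = exp(-lambda / 2)
```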
@@ -513,8 +512,8 @@ def velocity(
         score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
 
         # compute velocity for the ODE depending on the noise schedule
-        f, g = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
-        out = f - 0.5 * ops.square(g) * score
+        f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
+        out = f - 0.5 * g_squared * score
 
         # todo: for the SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
         return out
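With the squared term returned directly, the probability-flow velocity f - 0.5 * g(t)^2 * score is computed without a redundant sqrt-then-square round trip. A minimal sketch of how such a velocity could drive sampling; `velocity_fn` and the time grid are hypothetical stand-ins, not this class's API:

```python
def euler_ode_sample(velocity_fn, z, times):
    # integrate dz/dt = f(z, t) - 0.5 * g(t)^2 * score(z, lambda) with explicit Euler steps
    for t_curr, t_next in zip(times[:-1], times[1:]):
        z = z + (t_next - t_curr) * velocity_fn(z, t_curr)
    return z
```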
@@ -680,5 +679,5 @@ def compute_metrics(
         # apply sample weight
         loss = weighted_mean(loss, sample_weight)
 
-        base_metrics = super().compute_metrics(x, conditions, sample_weight, stage)
+        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
         return base_metrics | {"loss": loss}
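On this last hunk: switching to keyword arguments guards the `super()` call against signature drift in the base class. A small self-contained illustration with hypothetical signatures:

```python
class Parent:
    # suppose a later refactor inserts `y` before `conditions`
    def compute_metrics(self, x, y=None, conditions=None, sample_weight=None, stage=None):
        return {"conditions": conditions}

p = Parent()
positional = p.compute_metrics("x", "my_conditions", None, "training")  # binds to y, not conditions
keyword = p.compute_metrics("x", conditions="my_conditions", stage="training")
assert positional["conditions"] is None and keyword["conditions"] == "my_conditions"
```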