@@ -75,7 +75,7 @@ def derivative_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) ->
 
     def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]:
         r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.
-        Usually it can be derived from the derivative of the schedule:
+        It can be derived from the derivative of the schedule:
            \beta(t) = d/dt log(1 + e^(-snr(t)))
            f(z, t) = -0.5 * \beta(t) * z
            g(t)^2 = \beta(t)
@@ -85,9 +85,8 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
 
         For a variance exploding schedule, one should set f(z, t) = 0.
         """
-        # Default implementation is to return the diffusion term only
         beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
-        if x is None:  # return g only
+        if x is None:  # return g^2 only
             return beta
         if self.variance_type == "preserving":
             f = -0.5 * beta * x
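Note: as a quick illustration of the relations in the docstring above, here is a minimal standalone sketch (not part of the diff; the toy schedule `toy_log_snr` and the finite-difference helper `toy_beta` are hypothetical) showing how the variance-preserving drift f(z, t) and squared diffusion g(t)^2 follow from the derivative of log(1 + e^(-snr(t))):

import numpy as np

def toy_log_snr(t):
    # hypothetical linear log-SNR schedule, for illustration only
    return -2.0 * t

def toy_beta(t, eps=1e-4):
    # beta(t) = d/dt log(1 + e^(-snr(t))), approximated by central differences
    f = lambda u: np.log1p(np.exp(-toy_log_snr(u)))
    return (f(t + eps) - f(t - eps)) / (2 * eps)

t = 0.5
beta = toy_beta(t)
z = np.array([0.3, -1.2])
drift = -0.5 * beta * z   # f(z, t) for the variance-preserving case
g_squared = beta          # g(t)^2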
@@ -121,7 +120,7 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda). Default is 1.
         Generally, weighting functions should be defined for a noise prediction loss.
         """
-        # sigmoid: ops.sigmoid(-log_snr_t / 2), based on Kingma et al. (2023)
+        # sigmoid: ops.sigmoid(-log_snr_t + 2), based on Kingma et al. (2023)
         # min-snr with gamma = 5, based on Hang et al. (2023)
         # 1 / ops.cosh(log_snr_t / 2) * ops.minimum(ops.ones_like(log_snr_t), gamma * ops.exp(-log_snr_t))
         return ops.ones_like(log_snr_t)
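For reference, a standalone sketch of the two alternative weightings named in the comments above, with NumPy as a stand-in for keras.ops (the function names are illustrative, not from the codebase):

import numpy as np

def sigmoid_weight(log_snr_t, bias=2.0):
    # sigmoid weighting, Kingma et al. (2023): w(lambda) = sigmoid(-lambda + bias)
    return 1.0 / (1.0 + np.exp(log_snr_t - bias))

def min_snr_weight(log_snr_t, gamma=5.0):
    # min-SNR weighting, Hang et al. (2023), as quoted in the comment above
    return 1.0 / np.cosh(log_snr_t / 2) * np.minimum(1.0, gamma * np.exp(-log_snr_t))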
@@ -291,9 +290,9 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
         self._log_snr_min = -2 * ops.log(sigma_max)
         self._log_snr_max = -2 * ops.log(sigma_min)
         # t is not truncated for EDM by definition of the sampling schedule
-        # training bounds are not so important, but should be set to avoid numerical issues
-        self._log_snr_min_training = self._log_snr_min * 2  # one is never sampler during training
-        self._log_snr_max_training = self._log_snr_max * 2  # 0 is almost surely never sampled during training
+        # training bounds should be set to avoid numerical issues
+        self._log_snr_min_training = self._log_snr_min - 1  # t = 1 is never sampled during training
+        self._log_snr_max_training = self._log_snr_max + 1  # t = 0 is almost surely never sampled during training
 
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -304,14 +303,9 @@ def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
             snr = -(loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2))
             snr = keras.ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
         else:  # sampling
-            snr = (
-                -2
-                * self.rho
-                * ops.log(
-                    self.sigma_max ** (1 / self.rho)
-                    + (1 - t) * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
-                )
-            )
+            sigma_min_rho = self.sigma_min ** (1 / self.rho)
+            sigma_max_rho = self.sigma_max ** (1 / self.rho)
+            snr = -2 * self.rho * ops.log(sigma_max_rho + (1 - t) * (sigma_min_rho - sigma_max_rho))
         return snr
 
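As a numeric sanity check of the sampling branch (a standalone sketch; sigma_min = 0.002, sigma_max = 80 and rho = 7 are typical EDM values assumed here, not read from this diff):

import numpy as np

sigma_min, sigma_max, rho = 0.002, 80.0, 7.0  # assumed EDM-style values

def sampling_log_snr(t):
    # log-SNR along the EDM sampling schedule, as in the branch above
    smin, smax = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return -2 * rho * np.log(smax + (1 - t) * (smin - smax))

# Endpoints match the bounds set in __init__:
assert np.isclose(sampling_log_snr(1.0), -2 * np.log(sigma_max))  # _log_snr_min
assert np.isclose(sampling_log_snr(0.0), -2 * np.log(sigma_min))  # _log_snr_max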
317311 def get_t_from_log_snr (self , log_snr_t : Union [float , Tensor ], training : bool ) -> Tensor :
@@ -325,10 +319,9 @@ def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) ->
325319 else : # sampling
326320 # SNR = -2 * rho * log(sigma_max ** (1/rho) + (1 - t) * (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
327321 # => t = 1 - ((exp(-snr/(2*rho)) - sigma_max ** (1/rho)) / (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
328- t = 1 - (
329- (ops .exp (- log_snr_t / (2 * self .rho )) - self .sigma_max ** (1 / self .rho ))
330- / (self .sigma_min ** (1 / self .rho ) - self .sigma_max ** (1 / self .rho ))
331- )
322+ sigma_min_rho = self .sigma_min ** (1 / self .rho )
323+ sigma_max_rho = self .sigma_max ** (1 / self .rho )
324+ t = 1 - ((ops .exp (- log_snr_t / (2 * self .rho )) - sigma_max_rho ) / (sigma_min_rho - sigma_max_rho ))
332325 return t
333326
334327 def derivative_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
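And a round-trip check that the two sampling branches invert each other, mirroring the derivation in the comments above (same assumed constants; standalone sketch):

import numpy as np

sigma_min, sigma_max, rho = 0.002, 80.0, 7.0  # assumed EDM-style values
smin, smax = sigma_min ** (1 / rho), sigma_max ** (1 / rho)

def sampling_log_snr(t):
    return -2 * rho * np.log(smax + (1 - t) * (smin - smax))

def sampling_t_from_log_snr(log_snr_t):
    # inverse of sampling_log_snr
    return 1 - ((np.exp(-log_snr_t / (2 * rho)) - smax) / (smin - smax))

for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    assert np.isclose(sampling_t_from_log_snr(sampling_log_snr(t)), t)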
@@ -354,6 +347,13 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
         return (ops.exp(-log_snr_t) + ops.square(self.sigma_data)) / ops.square(self.sigma_data)
 
+    def get_config(self):
+        return dict(sigma_data=self.sigma_data, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
+
 
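The added get_config/from_config pair enables Keras-style serialization of the schedule. A minimal usage sketch, assuming the surrounding class is the EDM noise schedule with the constructor signature shown in the hunk above (the class name EDMNoiseSchedule and the sigma_max value are inferred, not visible in this diff):

schedule = EDMNoiseSchedule(sigma_data=0.5, sigma_min=0.002, sigma_max=80.0)
config = schedule.get_config()
restored = EDMNoiseSchedule.from_config(config)  # reconstructs an equivalent schedule
assert restored.sigma_data == schedule.sigma_data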
 @serializable
 class DiffusionModel(InferenceNetwork):
@@ -510,15 +510,15 @@ def convert_prediction_to_x(
         elif self.prediction_type == "noise":
             # convert noise prediction into x
             x = (z - sigma_t * pred) / alpha_t
-        elif self.prediction_type == "x":
-            x = pred
-        elif self.prediction_type == "score":
-            x = (z + sigma_t**2 * pred) / alpha_t
-        else:  # self.prediction_type == 'F':  # EDM
+        elif self.prediction_type == "F":  # EDM
             sigma_data = self.noise_schedule.sigma_data
             x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2)
             x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)
             x = x1 * z + x2 * pred
+        elif self.prediction_type == "x":
+            x = pred
+        else:  # "score"
+            x = (z + sigma_t**2 * pred) / alpha_t
 
         if clip_x:
             x = keras.ops.clip(x, self._clip_min, self._clip_max)
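As an illustration of the "noise" branch above, a standalone sketch: with z = alpha_t * x + sigma_t * eps, a perfect noise prediction recovers x exactly (the variance-preserving relations for alpha_t and sigma_t are assumed here, not taken from this diff):

import numpy as np

log_snr_t = 1.3  # arbitrary example value
# variance-preserving relations (assumed): alpha_t^2 + sigma_t^2 = 1
alpha_t = np.sqrt(1.0 / (1.0 + np.exp(-log_snr_t)))
sigma_t = np.sqrt(1.0 / (1.0 + np.exp(log_snr_t)))
x_true, eps = 0.7, -0.2
z = alpha_t * x_true + sigma_t * eps
x = (z - sigma_t * eps) / alpha_t  # the "noise" branch above
assert np.isclose(x, x_true)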
@@ -606,7 +606,7 @@ def _forward(
             | kwargs
         )
         if integrate_kwargs["method"] == "euler_maruyama":
-            raise ValueError("Stoachastic methods are not supported for forward integration.")
+            raise ValueError("Stochastic methods are not supported for forward integration.")
 
         if density:
 
@@ -661,7 +661,7 @@ def _inverse(
         )
         if density:
             if integrate_kwargs["method"] == "euler_maruyama":
-                raise ValueError("Stoachastic methods are not supported for density computation.")
+                raise ValueError("Stochastic methods are not supported for density computation.")
 
             def deltas(time, xz):
                 v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)