@@ -37,11 +37,22 @@ class NoiseSchedule(ABC):
3737 Augmentation: Kingma et al. (2023)
3838 """
3939
# Method of NoiseSchedule (abstract base class; class header is outside this
# scraped-diff view). Reconstructed from the post-commit (+) lines of the diff.
def __init__(self, name: str, variance_type: str):
    """Initialize the noise schedule.

    Parameters
    ----------
    name : str
        Identifier of the schedule (e.g. "linear_noise_schedule").
    variance_type : str
        Either 'preserving' or 'exploding'; selects the drift term in
        get_drift_diffusion, the alpha/sigma parameterization in
        get_alpha_sigma, and the base-distribution scale.
    """
    self.name = name
    self.variance_type = variance_type  # 'exploding' or 'preserving'
    # Placeholder log-SNR (lambda) bounds; subclasses are expected to
    # overwrite these with schedule-specific values.
    self._log_snr_min = ops.convert_to_tensor(-15)
    self._log_snr_max = ops.convert_to_tensor(15)
45+
# Method of NoiseSchedule (class header is outside this scraped-diff view).
# Reconstructed from the post-commit (+) lines of the diff.
@property
def scale_base_distribution(self):
    """Get the scale of the base distribution.

    Variance-preserving schedules keep unit variance, so the scale is 1.
    Variance-exploding schedules (e.g. EDM) scale the base distribution by
    the maximum noise level sigma_max = sqrt(exp(-log_snr_min)); note the
    exploding branch returns exp(-log_snr_min), i.e. sigma_max^2 — kept as
    in the original code.
    """
    if self.variance_type == "preserving":
        return 1.0
    elif self.variance_type == "exploding":
        # e.g., EDM is a variance exploding schedule
        return ops.exp(-self._log_snr_min)
    else:
        raise ValueError(f"Unknown variance type: {self.variance_type}")
4556
4657 @abstractmethod
4758 def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
# Method of NoiseSchedule (class header is outside this scraped-diff view).
# Signature recovered from the hunk header and the (byte-identical) EDM
# override deleted later in this same diff.
def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = True) -> tuple[Tensor, Tensor]:
    """Compute the drift and optionally the diffusion term of the reverse SDE.

    beta(t) = d/dt log(1 + e^(-snr(t)))
    g(t)^2  = beta(t)
    f(z, t) = -0.5 * beta(t) * z   (variance preserving)
    f(z, t) = 0                    (variance exploding)

    SDE: dz = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
    ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt

    If x is None, only the diffusion term g = sqrt(beta) is returned.
    """
    beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
    if x is None:  # return g only
        return ops.sqrt(beta)
    if self.variance_type == "preserving":
        f = -0.5 * beta * x
    elif self.variance_type == "exploding":
        f = ops.zeros_like(beta)
    else:
        raise ValueError(f"Unknown variance type: {self.variance_type}")
    return f, ops.sqrt(beta)
7995
# Method of NoiseSchedule (class header is outside this scraped-diff view).
# Reconstructed from the post-commit (+) lines; the original mixed
# `keras.ops.*` and bare `ops.*` for the same module — normalized to bare
# `ops` for consistency with the rest of the file (assumes the file does
# `from keras import ops` — TODO confirm against the import block).
def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
    """Get alpha and sigma for a given log signal-to-noise ratio (lambda).

    Variance preserving:
        alpha(t) = sqrt(sigmoid(log_snr_t))
        sigma(t) = sqrt(sigmoid(-log_snr_t))
    Variance exploding:
        alpha(t) = 1
        sigma(t) = sqrt(exp(-log_snr_t))

    `training` is accepted for signature consistency with the other schedule
    methods but is not used here.
    """
    if self.variance_type == "preserving":
        # variance preserving schedule
        alpha_t = ops.sqrt(ops.sigmoid(log_snr_t))
        sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t))
    elif self.variance_type == "exploding":
        # variance exploding schedule
        alpha_t = ops.ones_like(log_snr_t)
        sigma_t = ops.sqrt(ops.exp(-log_snr_t))
    else:
        raise ValueError(f"Unknown variance type: {self.variance_type}")
    return alpha_t, sigma_t
89115
90116 def get_weights_for_snr (self , log_snr_t : Tensor ) -> Tensor :
@@ -106,7 +132,7 @@ class LinearNoiseSchedule(NoiseSchedule):
106132 """
107133
# Method of LinearNoiseSchedule (class header is outside this scraped-diff
# view). Reconstructed from the post-commit (+) lines of the diff.
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
    """Initialize the linear noise schedule (variance preserving).

    Parameters
    ----------
    min_log_snr : float
        Lower truncation bound on the log signal-to-noise ratio.
    max_log_snr : float
        Upper truncation bound on the log signal-to-noise ratio.
    """
    super().__init__(name="linear_noise_schedule", variance_type="preserving")
    self._log_snr_min = ops.convert_to_tensor(min_log_snr)
    self._log_snr_max = ops.convert_to_tensor(max_log_snr)
112138
@@ -155,7 +181,7 @@ class CosineNoiseSchedule(NoiseSchedule):
155181 """
156182
# Method of CosineNoiseSchedule (class header is outside this scraped-diff
# view). Reconstructed from the post-commit (+) lines of the diff.
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0):
    """Initialize the cosine noise schedule (variance preserving).

    Parameters
    ----------
    min_log_snr : float
        Lower truncation bound on the log signal-to-noise ratio.
    max_log_snr : float
        Upper truncation bound on the log signal-to-noise ratio.
    s_shift_cosine : float
        Additive shift applied to the log-SNR of the cosine schedule.
    """
    super().__init__(name="cosine_noise_schedule", variance_type="preserving")
    self._log_snr_min = ops.convert_to_tensor(min_log_snr)
    self._log_snr_max = ops.convert_to_tensor(max_log_snr)
    self._s_shift_cosine = ops.convert_to_tensor(s_shift_cosine)
@@ -202,7 +228,7 @@ class EDMNoiseSchedule(NoiseSchedule):
202228 """
203229
# NOTE(review): EDMNoiseSchedule.__init__ shown as a scraped diff hunk. The
# hunk header below skips old lines 209-215 (presumably where _log_snr_min /
# _log_snr_max are derived from sigma_min / sigma_max — not visible here), so
# the method is incomplete in this view and is left byte-identical.
204230 def __init__ (self , sigma_data : float = 0.5 , sigma_min : float = 0.002 , sigma_max : float = 80 ):
205- super ().__init__ (name = "edm_noise_schedule" )
206232 self .sigma_data = ops .convert_to_tensor (sigma_data )
207233 self .sigma_max = ops .convert_to_tensor (sigma_max )
208234 self .sigma_min = ops .convert_to_tensor (sigma_min )
@@ -216,9 +242,6 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
# NOTE(review): BUG — both lines below pass log_snr_t=self._log_snr_max, so
# _t_min == _t_max; the truncation map `_t_min + (_t_max - _t_min) * t` used
# in get_log_snr (visible later in this diff) then collapses to a constant.
# _t_max should presumably be computed from self._log_snr_min — TODO confirm
# against upstream.
216242 self ._t_min = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
217243 self ._t_max = self .get_t_from_log_snr (log_snr_t = self ._log_snr_max , training = True )
218244
# NOTE(review): the removed (-) lines below moved into the base-class
# scale_base_distribution property, selected via variance_type="exploding".
219- # EDM is a variance exploding schedule
220- self .scale_base_distribution = ops .exp (- self ._log_snr_min )
221-
222245 def get_log_snr (self , t : Tensor , training : bool ) -> Tensor :
223246 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
224247 t_trunc = self ._t_min + (self ._t_max - self ._t_min ) * t
@@ -278,28 +301,6 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
278301 factor = ops .exp (- log_snr_t ) / (1 + ops .exp (- log_snr_t ))
279302 return - factor * dsnr_dt
280303
281- def get_drift_diffusion (self , log_snr_t : Tensor , x : Tensor = None , training : bool = True ) -> tuple [Tensor , Tensor ]:
282- """Compute the drift and optionally the diffusion term for the variance exploding reverse SDE.
283- \b eta(t) = d/dt log(1 + e^(-snr(t)))
284- f(z, t) = 0
285- g(t)^2 = \b eta(t)
286-
287- SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
288- ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt
289- """
290- # Default implementation is to return the diffusion term only
291- beta = self .derivative_log_snr (log_snr_t = log_snr_t , training = training )
292- if x is None : # return g only
293- return ops .sqrt (beta )
294- f = ops .zeros_like (beta ) # variance exploding schedule
295- return f , ops .sqrt (beta )
296-
297- def get_alpha_sigma (self , log_snr_t : Tensor , training : bool ) -> tuple [Tensor , Tensor ]:
298- """Get alpha and sigma for a given log signal-to-noise ratio (lambda) for a variance exploding schedule."""
299- alpha_t = ops .ones_like (log_snr_t )
300- sigma_t = ops .sqrt (ops .exp (- log_snr_t ))
301- return alpha_t , sigma_t
302-
# Method of EDMNoiseSchedule (class header is outside this scraped-diff view).
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
    """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).

    Computes sigma(t)^2 + sigma_data^2 with sigma(t)^2 = exp(-lambda).
    The original body hard-coded 0.5 ** 2, which silently diverges from a
    user-supplied sigma_data (a configurable constructor parameter whose
    default is 0.5); use the stored attribute so the weight tracks the
    configured value. Behavior is unchanged for the default sigma_data=0.5.
    """
    return ops.exp(-log_snr_t) + self.sigma_data**2
0 commit comments