1919)
2020
2121
22+ @serializable
2223class NoiseSchedule (ABC ):
23- """Noise schedule for diffusion models. We follow the notation from [1].
24+ r """Noise schedule for diffusion models. We follow the notation from [1].
2425
2526 The diffusion process is defined by a noise schedule, which determines how the noise level changes over time.
2627 We define the noise schedule as a function of the log signal-to-noise ratio (lambda), which can be
@@ -39,8 +40,8 @@ class NoiseSchedule(ABC):
3940 def __init__ (self , name : str , variance_type : str ):
4041 self .name = name
4142 self .variance_type = variance_type # 'exploding' or 'preserving'
42- self ._log_snr_min = ops . convert_to_tensor ( - 15 ) # should be set in the subclasses
43- self ._log_snr_max = ops . convert_to_tensor ( 15 ) # should be set in the subclasses
43+ self ._log_snr_min = - 15 # should be set in the subclasses
44+ self ._log_snr_max = 15 # should be set in the subclasses
4445
4546 @property
4647 def scale_base_distribution (self ):
@@ -65,11 +66,11 @@ def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
6566
6667 @abstractmethod
6768 def derivative_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
68- """Compute \b eta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE."""
69+ r """Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE."""
6970 pass
7071
7172 def get_drift_diffusion (self , log_snr_t : Tensor , x : Tensor = None , training : bool = True ) -> tuple [Tensor , Tensor ]:
72- """Compute the drift and optionally the diffusion term for the reverse SDE.
73+ r """Compute the drift and optionally the diffusion term for the reverse SDE.
7374 Usually it can be derived from the derivative of the schedule:
7475 \beta(t) = d/dt log(1 + e^(-snr(t)))
7576 f(z, t) = -0.5 * \beta(t) * z
@@ -121,7 +122,15 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
121122 # 1 / ops.cosh(log_snr_t / 2) * ops.minimum(ops.ones_like(log_snr_t), gamma * ops.exp(-log_snr_t))
122123 return ops .ones_like (log_snr_t )
123124
125+ def get_config (self ):
126+ return dict (name = self .name , variance_type = self .variance_type )
127+
128+ @classmethod
129+ def from_config (cls , config , custom_objects = None ):
130+ return cls (** deserialize (config , custom_objects = custom_objects ))
131+
124132
133+ @serializable
125134class LinearNoiseSchedule (NoiseSchedule ):
126135 """Linear noise schedule for diffusion models.
127136
@@ -171,7 +180,15 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
171180 sigma_t = self .get_alpha_sigma (log_snr_t = log_snr_t , training = True )[1 ]
172181 return ops .square (g / sigma_t )
173182
183+ def get_config (self ):
184+ return dict (min_log_snr = self ._log_snr_min , max_log_snr = self ._log_snr_max )
174185
186+ @classmethod
187+ def from_config (cls , config , custom_objects = None ):
188+ return cls (** deserialize (config , custom_objects = custom_objects ))
189+
190+
191+ @serializable
175192class CosineNoiseSchedule (NoiseSchedule ):
176193 """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
177194 For images, use s_shift_cosine = log(base_resolution / d), where d is the used resolution of the image.
@@ -181,7 +198,7 @@ class CosineNoiseSchedule(NoiseSchedule):
181198
182199 def __init__ (self , min_log_snr : float = - 15 , max_log_snr : float = 15 , s_shift_cosine : float = 0.0 ):
183200 super ().__init__ (name = "cosine_noise_schedule" , variance_type = "preserving" )
184- self ._s_shift_cosine = ops . convert_to_tensor ( s_shift_cosine )
201+ self ._s_shift_cosine = s_shift_cosine
185202 self ._log_snr_min = min_log_snr
186203 self ._log_snr_max = max_log_snr
187204 self ._s_shift_cosine = s_shift_cosine
@@ -220,7 +237,15 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
220237 """
221238 return ops .sigmoid (- log_snr_t / 2 )
222239
240+ def get_config (self ):
241+ return dict (min_log_snr = self ._log_snr_min , max_log_snr = self ._log_snr_max , s_shift_cosine = self ._s_shift_cosine )
223242
243+ @classmethod
244+ def from_config (cls , config , custom_objects = None ):
245+ return cls (** deserialize (config , custom_objects = custom_objects ))
246+
247+
248+ @serializable
224249class EDMNoiseSchedule (NoiseSchedule ):
225250 """EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
226251
@@ -472,6 +497,7 @@ def velocity(
472497 ) -> Tensor :
473498 # calculate the current noise level and transform into correct shape
474499 log_snr_t = expand_right_as (self .noise_schedule .get_log_snr (t = time , training = training ), xz )
500+ log_snr_t = keras .ops .broadcast_to (log_snr_t , keras .ops .shape (xz )[:- 1 ] + (1 ,))
475501 alpha_t , sigma_t = self .noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t , training = training )
476502
477503 if conditions is None :
0 commit comments