@@ -158,16 +158,34 @@ def validate(self):
158158@serializable ("bayesflow.experimental" )
159159class CosineNoiseSchedule (NoiseSchedule ):
160160 """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
161- For images, use s_shift_cosine = log(base_resolution / d), where d is the used resolution of the image.
162161
163162 [1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2021)
164163 """
165164
166165 def __init__ (
167- self , min_log_snr : float = - 15 , max_log_snr : float = 15 , s_shift_cosine : float = 0.0 , weighting : str = "sigmoid"
166+ self ,
167+ min_log_snr : float = - 15 ,
168+ max_log_snr : float = 15 ,
169+ shift : float = 0.0 ,
170+ weighting : Literal ["sigmoid" , "likelihood_weighting" ] = "sigmoid" ,
168171 ):
172+ """
173+ Initialize the cosine noise schedule.
174+
175+ Parameters
176+ ----------
177+ min_log_snr : float, optional
178+ The minimum log signal-to-noise ratio (lambda). Default is -15.
179+ max_log_snr : float, optional
180+ The maximum log signal-to-noise ratio (lambda). Default is 15.
181+ shift : float, optional
182+ Shift parameter; the log signal-to-noise ratio (lambda) is offset by 2 * shift. Default is 0.0.
183+ For images, use shift = log(base_resolution / d), where d is the resolution of the images being used.
184+ weighting : Literal["sigmoid", "likelihood_weighting"], optional
185+ The type of weighting function to use for the noise schedule. Default is "sigmoid".
186+ """
169187 super ().__init__ (name = "cosine_noise_schedule" , variance_type = "preserving" , weighting = weighting )
170- self ._s_shift_cosine = s_shift_cosine
188+ self ._shift = shift
171189 self .log_snr_min = min_log_snr
172190 self .log_snr_max = max_log_snr
173191
@@ -180,12 +198,12 @@ def _truncated_t(self, t: Tensor) -> Tensor:
180198 def get_log_snr (self , t : Union [float , Tensor ], training : bool ) -> Tensor :
181199 """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
182200 t_trunc = self ._truncated_t (t )
183- return - 2 * ops .log (ops .tan (math .pi * t_trunc * 0.5 )) + 2 * self ._s_shift_cosine
201+ return - 2 * ops .log (ops .tan (math .pi * t_trunc * 0.5 )) + 2 * self ._shift
184202
185203 def get_t_from_log_snr (self , log_snr_t : Union [Tensor , float ], training : bool ) -> Tensor :
186204 """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
187205 # log_snr = -2 * log(tan(pi*t/2)) + 2*shift => t = 2/pi * arctan(exp((2*shift - log_snr)/2))
188- return 2 / math .pi * ops .arctan (ops .exp ((2 * self ._s_shift_cosine - log_snr_t ) * 0.5 ))
206+ return 2 / math .pi * ops .arctan (ops .exp ((2 * self ._shift - log_snr_t ) * 0.5 ))
189207
190208 def derivative_log_snr (self , log_snr_t : Tensor , training : bool ) -> Tensor :
191209 """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
@@ -202,7 +220,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
202220 return - factor * dsnr_dt
203221
204222 def get_config (self ):
205- return dict (min_log_snr = self .log_snr_min , max_log_snr = self .log_snr_max , s_shift_cosine = self ._s_shift_cosine )
223+ return dict (min_log_snr = self .log_snr_min , max_log_snr = self .log_snr_max , shift = self ._shift )
206224
207225 @classmethod
208226 def from_config (cls , config , custom_objects = None ):
0 commit comments