Skip to content

Commit 0e5a48c

Browse files
committed
fix Literal
1 parent 0781032 commit 0e5a48c

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

bayesflow/experimental/noise_schedules.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,34 @@ def validate(self):
158158
@serializable("bayesflow.experimental")
159159
class CosineNoiseSchedule(NoiseSchedule):
160160
"""Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
161-
For images, use s_shift_cosine = log(base_resolution / d), where d is the used resolution of the image.
162161
163162
[1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2021)
164163
"""
165164

166165
def __init__(
167-
self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0, weighting: str = "sigmoid"
166+
self,
167+
min_log_snr: float = -15,
168+
max_log_snr: float = 15,
169+
shift: float = 0.0,
170+
weighting: Literal["sigmoid", "likelihood_weighting"] = "sigmoid",
168171
):
172+
"""
173+
Initialize the cosine noise schedule.
174+
175+
Parameters
176+
----------
177+
min_log_snr : float, optional
178+
The minimum log signal-to-noise ratio (lambda). Default is -15.
179+
max_log_snr : float, optional
180+
The maximum log signal-to-noise ratio (lambda). Default is 15.
181+
shift : float, optional
182+
Shift of the schedule; the log signal-to-noise ratio (lambda) is offset by 2 * shift. Default is 0.0.
183+
For images, use shift = log(base_resolution / d), where d is the used resolution of the image.
184+
weighting : Literal["sigmoid", "likelihood_weighting"], optional
185+
The type of weighting function to use for the noise schedule. Default is "sigmoid".
186+
"""
169187
super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
170-
self._s_shift_cosine = s_shift_cosine
188+
self._shift = shift
171189
self.log_snr_min = min_log_snr
172190
self.log_snr_max = max_log_snr
173191

@@ -180,12 +198,12 @@ def _truncated_t(self, t: Tensor) -> Tensor:
180198
def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
181199
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
182200
t_trunc = self._truncated_t(t)
183-
return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._s_shift_cosine
201+
return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._shift
184202

185203
def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor:
186204
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
187205
# log_snr = -2 * log(tan(pi*t/2)) + 2*shift  =>  t = 2/pi * arctan(exp((2*shift - log_snr)/2))
188-
return 2 / math.pi * ops.arctan(ops.exp((2 * self._s_shift_cosine - log_snr_t) * 0.5))
206+
return 2 / math.pi * ops.arctan(ops.exp((2 * self._shift - log_snr_t) * 0.5))
189207

190208
def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
191209
"""Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
@@ -202,7 +220,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
202220
return -factor * dsnr_dt
203221

204222
def get_config(self):
205-
return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, s_shift_cosine=self._s_shift_cosine)
223+
return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift)
206224

207225
@classmethod
208226
def from_config(cls, config, custom_objects=None):

0 commit comments

Comments
 (0)