Skip to content

Commit 495ed29

Browse files
committed
fix validate noise schedule for training
1 parent 95ca126 commit 495ed29

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def scale_base_distribution(self):
5858
raise ValueError(f"Unknown variance type: {self.variance_type}")
5959

6060
@abstractmethod
61-
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
61+
def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
6262
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
6363
pass
6464

@@ -145,10 +145,10 @@ def validate(self):
145145
raise ValueError("t(0) must be finite.")
146146
if not ops.isfinite(self.get_t_from_log_snr(self._log_snr_min, training=training)):
147147
raise ValueError("t(1) must be finite.")
148-
if not ops.isfinite(self.derivative_log_snr(self._log_snr_max, training=training)):
149-
raise ValueError("dt/t log_snr(0) must be finite.")
150-
if not ops.isfinite(self.derivative_log_snr(self._log_snr_min, training=training)):
151-
raise ValueError("dt/t log_snr(1) must be finite.")
148+
if not ops.isfinite(self.derivative_log_snr(self._log_snr_max, training=False)):
149+
raise ValueError("dt/t log_snr(0) must be finite.")
150+
if not ops.isfinite(self.derivative_log_snr(self._log_snr_min, training=False)):
151+
raise ValueError("dt/t log_snr(1) must be finite.")
152152

153153

154154
@serializable
@@ -168,7 +168,7 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
168168
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
169169
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
170170

171-
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
171+
def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
172172
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
173173
t_trunc = self._t_min + (self._t_max - self._t_min) * t
174174
# SNR = -log(exp(t^2) - 1)
@@ -228,7 +228,7 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_co
228228
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
229229
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
230230

231-
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
231+
def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
232232
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
233233
t_trunc = self._t_min + (self._t_max - self._t_min) * t
234234
# SNR = -2 * log(tan(pi*t/2))
@@ -289,7 +289,7 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
289289
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
290290
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
291291

292-
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
292+
def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
293293
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
294294
t_trunc = self._t_min + (self._t_max - self._t_min) * t
295295
if training:

0 commit comments

Comments
 (0)