Skip to content

Commit cbd3568

Browse files
committed
swap mapping log_snr_min/max to t_min/max
1 parent ca52fc0 commit cbd3568

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
146146
self._log_snr_min = min_log_snr
147147
self._log_snr_max = max_log_snr
148148

149-
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
150-
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
149+
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
150+
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
151151

152152
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
153153
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -205,8 +205,8 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_co
205205
self._log_snr_max = max_log_snr
206206
self._s_shift_cosine = s_shift_cosine
207207

208-
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
209-
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
208+
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
209+
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
210210

211211
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
212212
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -266,8 +266,8 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
266266
# convert EDM parameters to signal-to-noise ratio formulation
267267
self._log_snr_min = -2 * ops.log(sigma_max)
268268
self._log_snr_max = -2 * ops.log(sigma_min)
269-
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
270-
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
269+
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
270+
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
271271

272272
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
273273
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -478,7 +478,7 @@ def convert_prediction_to_x(
478478
if self.prediction_type == "v":
479479
# convert v into x
480480
x = alpha_t * z - sigma_t * pred
481-
elif self.prediction_type == "e":
481+
elif self.prediction_type == "eps":
482482
# convert noise prediction into x
483483
x = (z - sigma_t * pred) / alpha_t
484484
elif self.prediction_type == "x":
@@ -552,8 +552,8 @@ def _forward(
552552
) -> Tensor | tuple[Tensor, Tensor]:
553553
integrate_kwargs = (
554554
{
555-
"start_time": 1.0,
556-
"stop_time": 0.0,
555+
"start_time": 0.0,
556+
"stop_time": 1.0,
557557
}
558558
| self.integrate_kwargs
559559
| kwargs
@@ -601,8 +601,8 @@ def _inverse(
601601
) -> Tensor | tuple[Tensor, Tensor]:
602602
integrate_kwargs = (
603603
{
604-
"start_time": 0.0,
605-
"stop_time": 1.0,
604+
"start_time": 1.0,
605+
"stop_time": 0.0,
606606
}
607607
| self.integrate_kwargs
608608
| kwargs

0 commit comments

Comments
 (0)