Skip to content

Commit 9b520bc

Browse files
committed
fix mapping min/max snr to t_min/max
1 parent ca52fc0 commit 9b520bc

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,15 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
146146
self._log_snr_min = min_log_snr
147147
self._log_snr_max = max_log_snr
148148

149-
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
150-
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
149+
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
150+
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
151151

152152
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
153153
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
154154
t_trunc = self._t_min + (self._t_max - self._t_min) * t
155155
# SNR = -log(exp(t^2) - 1)
156-
return -ops.log(ops.exp(ops.square(t_trunc)) - 1)
156+
# equivalent, but more stable: -t^2 - log(1 - exp(-t^2))
157+
return -ops.square(t_trunc) - ops.log(1 - ops.exp(-ops.square(t_trunc)))
157158

158159
def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
159160
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
@@ -205,8 +206,8 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_co
205206
self._log_snr_max = max_log_snr
206207
self._s_shift_cosine = s_shift_cosine
207208

208-
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
209-
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
209+
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
210+
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
210211

211212
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
212213
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -266,8 +267,8 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
266267
# convert EDM parameters to signal-to-noise ratio formulation
267268
self._log_snr_min = -2 * ops.log(sigma_max)
268269
self._log_snr_max = -2 * ops.log(sigma_min)
269-
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
270-
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
270+
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
271+
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
271272

272273
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
273274
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""

0 commit comments

Comments
 (0)