
Commit 01b33dc

fixes: use squared g, correct typo in _min_t
1 parent 2ce74f0 commit 01b33dc

File tree

1 file changed: +12 -13 lines

bayesflow/experimental/diffusion_model.py

Lines changed: 12 additions & 13 deletions
@@ -69,8 +69,8 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
         r"""Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE."""
         pass
 
-    def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = True) -> tuple[Tensor, Tensor]:
-        r"""Compute the drift and optionally the diffusion term for the reverse SDE.
+    def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]:
+        r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.
         Usually it can be derived from the derivative of the schedule:
         \beta(t) = d/dt log(1 + e^(-snr(t)))
         f(z, t) = -0.5 * \beta(t) * z

@@ -84,14 +84,14 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
         # Default implementation is to return the diffusion term only
         beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
         if x is None: # return g only
-            return ops.sqrt(beta)
+            return beta
         if self.variance_type == "preserving":
             f = -0.5 * beta * x
         elif self.variance_type == "exploding":
             f = ops.zeros_like(beta)
         else:
             raise ValueError(f"Unknown variance type: {self.variance_type}")
-        return f, ops.sqrt(beta)
+        return f, beta
 
     def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
         """Get alpha and sigma for a given log signal-to-noise ratio (lambda).

@@ -144,7 +144,7 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
         self._log_snr_min = min_log_snr
         self._log_snr_max = max_log_snr
 
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
+        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
         self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
 
     def get_log_snr(self, t: Tensor, training: bool) -> Tensor:

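The "_min_t" typo named in the commit message is this copy-paste error: _t_min and _t_max were both derived from _log_snr_max, so the two bounds coincided. A small illustration of the pre-fix behavior, with schedule standing in for any of the three affected classes:

# before the fix, both endpoints came from the same log-SNR value
t_min = schedule.get_t_from_log_snr(log_snr_t=schedule._log_snr_max, training=True)
t_max = schedule.get_t_from_log_snr(log_snr_t=schedule._log_snr_max, training=True)
assert t_min == t_max  # degenerate interval: no usable range of times
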
@@ -176,9 +176,9 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
         Default is the likelihood weighting based on Song et al. (2021).
         """
-        g = self.get_drift_diffusion(log_snr_t=log_snr_t)
+        g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
         sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
-        return ops.square(g / sigma_t)
+        return g_squared / ops.square(sigma_t)
 
     def get_config(self):
         return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max)

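Both versions compute the likelihood weighting w(lambda) = g(t)^2 / sigma(t)^2 from Song et al. (2021); the rewrite only adapts to the squared return value. A quick numerical sanity check with made-up values:

g_squared = 4.0                            # squared diffusion term (new API)
sigma_t = 0.5
old = (g_squared ** 0.5 / sigma_t) ** 2    # ops.square(g / sigma_t)
new = g_squared / sigma_t ** 2             # g_squared / ops.square(sigma_t)
assert old == new == 16.0
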
@@ -203,7 +203,7 @@ def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_co
         self._log_snr_max = max_log_snr
         self._s_shift_cosine = s_shift_cosine
 
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
+        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
         self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
 
     def get_log_snr(self, t: Tensor, training: bool) -> Tensor:

@@ -254,7 +254,6 @@ class EDMNoiseSchedule(NoiseSchedule):
 
     def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80):
         super().__init__(name="edm_noise_schedule", variance_type="exploding")
-        super().__init__(name="edm_noise_schedule")
         self.sigma_data = sigma_data
         self.sigma_max = sigma_max
         self.sigma_min = sigma_min

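The deleted line re-ran the base initializer without variance_type, which would overwrite the "exploding" setting with the base default. A self-contained illustration of that failure mode; Base and its "preserving" default are assumptions for the sketch, not the actual NoiseSchedule code:

class Base:
    def __init__(self, variance_type="preserving"):  # assumed default
        self.variance_type = variance_type

class Derived(Base):
    def __init__(self):
        super().__init__(variance_type="exploding")
        super().__init__()  # second call silently resets variance_type

assert Derived().variance_type == "preserving"  # the bug this commit removes
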
@@ -265,7 +264,7 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
         # convert EDM parameters to signal-to-noise ratio formulation
         self._log_snr_min = -2 * ops.log(sigma_max)
         self._log_snr_max = -2 * ops.log(sigma_min)
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
+        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
         self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
 
     def get_log_snr(self, t: Tensor, training: bool) -> Tensor:

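With alpha = 1 for a variance-exploding schedule, log-SNR = log(alpha^2 / sigma^2) = -2 log(sigma), so the sigma bounds swap roles: sigma_max yields the minimum log-SNR. A quick check with the default values (plain math, independent of the commit's code):

import math

sigma_min, sigma_max = 0.002, 80.0
log_snr_min = -2 * math.log(sigma_max)  # ~ -8.76: the noisiest state
log_snr_max = -2 * math.log(sigma_min)  # ~ 12.43: the cleanest state
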
@@ -513,8 +512,8 @@ def velocity(
         score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
 
         # compute velocity for the ODE depending on the noise schedule
-        f, g = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
-        out = f - 0.5 * ops.square(g) * score
+        f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
+        out = f - 0.5 * g_squared * score
 
         # todo: for the SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
         return out

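For the SDE noted in the todo, the new convention means only the dW coefficient needs a square root. A hedged sketch of a single Euler-Maruyama step under these assumptions; dt, eps, and the surrounding variable names are illustrative, not part of this commit:

# one reverse-SDE step: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
drift = f - g_squared * score
eps = keras.random.normal(ops.shape(xz))
xz = xz + drift * dt + ops.sqrt(g_squared * ops.abs(dt)) * eps
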
@@ -680,5 +679,5 @@ def compute_metrics(
         # apply sample weight
         loss = weighted_mean(loss, sample_weight)
 
-        base_metrics = super().compute_metrics(x, conditions, sample_weight, stage)
+        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
         return base_metrics | {"loss": loss}
