
Commit 196683c

EDM training bounds
1 parent 194a503 commit 196683c

File tree

1 file changed: +9, -12 lines


bayesflow/experimental/diffusion_model.py

Lines changed: 9 additions & 12 deletions
@@ -276,7 +276,7 @@ class EDMNoiseSchedule(NoiseSchedule):
     [1] Elucidating the Design Space of Diffusion-Based Generative Models: Karras et al. (2022)
     """
 
-    def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80):
+    def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0):
         super().__init__(name="edm_noise_schedule", variance_type="exploding")
         self.sigma_data = sigma_data
         # training settings
@@ -291,26 +291,25 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
         self._log_snr_min = -2 * ops.log(sigma_max)
         self._log_snr_max = -2 * ops.log(sigma_min)
         # t is not truncated for EDM by definition of the sampling schedule
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=False)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=False)
+        # training bounds are not so important, but should be set to avoid numerical issues
+        self._log_snr_min_training = self._log_snr_min * 2  # one is never sampled during training
+        self._log_snr_max_training = self._log_snr_max * 2  # 0 is almost surely never sampled during training
 
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
-        t_trunc = self._t_min + (self._t_max - self._t_min) * t
         if training:
             # SNR = -dist.icdf(t_trunc)
             loc = -2 * self.p_mean
             scale = 2 * self.p_std
-            x = t_trunc
-            snr = -(loc + scale * ops.erfinv(2 * x - 1) * math.sqrt(2))
-            snr = keras.ops.clip(snr, x_min=self._log_snr_min, x_max=self._log_snr_max)
+            snr = -(loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2))
+            snr = keras.ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
         else:  # sampling
             snr = (
                 -2
                 * self.rho
                 * ops.log(
                     self.sigma_max ** (1 / self.rho)
-                    + (1 - t_trunc) * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
+                    + (1 - t) * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
                 )
             )
         return snr
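
Note on the training branch above: erfinv(2 * t - 1) * sqrt(2) maps t ~ U(0, 1) to a standard-normal quantile, so snr is minus a sample from a normal distribution with the given loc and scale. Because erfinv(x) diverges as x approaches +/-1, draws of t near 0 or 1 would overflow, hence the clip; the widened training bounds (twice the sampling bounds in log-SNR space) push that clip far into the tails instead of truncating at the sampling range itself. A minimal NumPy sketch of the mechanics, with p_mean and p_std values assumed for illustration (the class defaults are set outside this diff):

import numpy as np
from scipy.special import erfinv

# illustrative values; the actual class defaults are defined outside this diff
p_mean, p_std = -1.2, 1.2
sigma_min, sigma_max = 0.002, 80.0

log_snr_min = -2 * np.log(sigma_max)  # noisiest end of the sampling range
log_snr_max = -2 * np.log(sigma_min)  # cleanest end of the sampling range
# widened training bounds, as introduced in this commit
log_snr_min_training = 2 * log_snr_min
log_snr_max_training = 2 * log_snr_max

t = np.random.uniform(0.0, 1.0, size=100_000)
loc, scale = -2 * p_mean, 2 * p_std
snr = -(loc + scale * erfinv(2 * t - 1) * np.sqrt(2))  # minus a normal quantile
snr = np.clip(snr, log_snr_min_training, log_snr_max_training)  # guard the tails

With numbers like these the clip only fires several standard deviations out, matching the commit's comment that the training bounds exist to avoid numerical issues rather than to shape the distribution.
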
@@ -338,20 +337,18 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
             raise NotImplementedError("Derivative of log SNR is not implemented for training mode.")
         # sampling mode
         t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)
-        t_trunc = self._t_min + (self._t_max - self._t_min) * t
 
         # SNR = -2*rho*log(s_max + (1 - x)*(s_min - s_max))
         s_max = self.sigma_max ** (1 / self.rho)
         s_min = self.sigma_min ** (1 / self.rho)
-        u = s_max + (1 - t_trunc) * (s_min - s_max)
+        u = s_max + (1 - t) * (s_min - s_max)
         # d/dx snr = 2*rho*(s_min - s_max) / u
         dsnr_dx = 2 * self.rho * (s_min - s_max) / u
 
         # Using the chain rule on f(t) = log(1 + e^(-snr(t))):
         # f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt
-        dsnr_dt = dsnr_dx * (self._t_max - self._t_min)
         factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
-        return -factor * dsnr_dt
+        return -factor * dsnr_dx
 
     def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
