Skip to content

Commit 59a349b

Browse files
committed
minor change in diffusion weightings
1 parent 495ed29 commit 59a349b

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
257257
"""Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
258258
Default is the sigmoid weighting based on Kingma et al. (2023).
259259
"""
260-
return ops.sigmoid(-log_snr_t / 2)
260+
return ops.sigmoid(-log_snr_t + 2)
261261

262262
def get_config(self):
263263
return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max, s_shift_cosine=self._s_shift_cosine)
@@ -270,6 +270,7 @@ def from_config(cls, config, custom_objects=None):
270270
@serializable
271271
class EDMNoiseSchedule(NoiseSchedule):
272272
"""EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
273+
This should be used with the F-prediction type in the diffusion model.
273274
274275
[1] Elucidating the Design Space of Diffusion-Based Generative Models: Karras et al. (2022)
275276
"""
@@ -350,7 +351,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
350351

351352
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
352353
"""Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
353-
return ops.exp(-log_snr_t) + 0.5**2
354+
return (ops.exp(-log_snr_t) + ops.square(self.sigma_data)) / ops.square(self.sigma_data)
354355

355356

356357
@serializable
@@ -432,6 +433,10 @@ def __init__(
432433
if prediction_type not in ["velocity", "noise", "F"]: # F is EDM
433434
raise ValueError(f"Unknown prediction type: {prediction_type}")
434435
self.prediction_type = prediction_type
436+
if noise_schedule.name == "edm_noise_schedule" and prediction_type != "F":
437+
warnings.warn(
438+
"EDM noise schedule is build for F-prediction. Consider using F-prediction instead.",
439+
)
435440

436441
# clipping of prediction (after it was transformed to x-prediction)
437442
self._clip_min = -5.0

0 commit comments

Comments
 (0)