@@ -257,7 +257,7 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
257257 """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
258258 Default is the sigmoid weighting based on Kingma et al. (2023).
259259 """
260- return ops .sigmoid (- log_snr_t / 2 )
260+ return ops .sigmoid (- log_snr_t + 2 )
261261
262262 def get_config (self ):
263263 return dict (min_log_snr = self ._log_snr_min , max_log_snr = self ._log_snr_max , s_shift_cosine = self ._s_shift_cosine )
@@ -270,6 +270,7 @@ def from_config(cls, config, custom_objects=None):
270270@serializable
271271class EDMNoiseSchedule (NoiseSchedule ):
272272 """EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
273+ This should be used with the F-prediction type in the diffusion model.
273274
274275 [1] Elucidating the Design Space of Diffusion-Based Generative Models: Karras et al. (2022)
275276 """
@@ -350,7 +351,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
350351
351352 def get_weights_for_snr (self , log_snr_t : Tensor ) -> Tensor :
352353 """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
353- return ops .exp (- log_snr_t ) + 0.5 ** 2
354+ return ( ops .exp (- log_snr_t ) + ops . square ( self . sigma_data )) / ops . square ( self . sigma_data )
354355
355356
356357@serializable
@@ -432,6 +433,10 @@ def __init__(
432433 if prediction_type not in ["velocity" , "noise" , "F" ]: # F is EDM
433434 raise ValueError (f"Unknown prediction type: { prediction_type } " )
434435 self .prediction_type = prediction_type
436+ if noise_schedule .name == "edm_noise_schedule" and prediction_type != "F" :
437+ warnings .warn (
438+ "EDM noise schedule is build for F-prediction. Consider using F-prediction instead." ,
439+ )
435440
436441 # clipping of prediction (after it was transformed to x-prediction)
437442 self ._clip_min = - 5.0
0 commit comments