@@ -27,6 +27,14 @@ class VarianceType(Enum):
2727 EXPLODING = "exploding"
2828
2929
class PredictionType(str, Enum):
    """Parameterization used for the network output and for the training loss.

    The ``str`` mixin makes every member compare equal to its plain string value
    (``PredictionType.VELOCITY == "velocity"``). Callers in this file still pass
    bare strings (the constructor default for ``prediction_type`` is ``"velocity"``),
    and with a plain ``Enum`` those strings would never match a member in ``==`` or
    ``in`` checks.
    """

    VELOCITY = "velocity"
    NOISE = "noise"
    X = "x"
    F = "F"  # EDM
    SCORE = "score"
36+
37+
3038@serializable
3139class NoiseSchedule (ABC ):
3240 r"""Noise schedule for diffusion models. We follow the notation from [1].
@@ -45,11 +53,12 @@ class NoiseSchedule(ABC):
4553 Augmentation: Kingma et al. (2023)
4654 """
4755
def __init__(self, name: str, variance_type: VarianceType, weighting: str | None = None):
    """Store the settings shared by all noise schedules.

    Parameters
    ----------
    name : str
        Identifier of the schedule.
    variance_type : VarianceType
        Whether the schedule is variance 'exploding' or 'preserving'.
    weighting : str or None, optional
        Name of the loss weighting as a function of the log-SNR. The base
        ``get_weights_for_snr`` understands ``None`` (constant weight of one),
        ``"sigmoid"`` and ``"likelihood_weighting"`` and raises a ``ValueError``
        for any other value when it is called. Default is ``None``.
    """
    self.name = name
    self.variance_type = variance_type  # 'exploding' or 'preserving'
    self._log_snr_min = -15  # should be set in the subclasses
    self._log_snr_max = 15  # should be set in the subclasses
    self.weighting = weighting
5362
5463 @abstractmethod
5564 def get_log_snr (self , t : Union [float , Tensor ], training : bool ) -> Tensor :
@@ -113,10 +122,18 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
113122 """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda). Default is 1.
114123 Generally, weighting functions should be defined for a noise prediction loss.
115124 """
116- # sigmoid: ops.sigmoid(-log_snr_t + 2), based on Kingma et al. (2023)
117- # min-snr with gamma = 5, based on Hang et al. (2023)
118- # 1 / ops.cosh(log_snr_t / 2) * ops.minimum(ops.ones_like(log_snr_t), gamma * ops.exp(-log_snr_t))
119- return ops .ones_like (log_snr_t )
125+ if self .weighting is None :
126+ return ops .ones_like (log_snr_t )
127+ elif self .weighting == "sigmoid" :
128+ # sigmoid weighting based on Kingma et al. (2023)
129+ return ops .sigmoid (- log_snr_t + 2 )
130+ elif self .weighting == "likelihood_weighting" :
131+ # likelihood weighting based on Song et al. (2021)
132+ g_squared = self .get_drift_diffusion (log_snr_t = log_snr_t )
133+ sigma_t = self .get_alpha_sigma (log_snr_t = log_snr_t , training = True )[1 ]
134+ return g_squared / ops .square (sigma_t )
135+ else :
136+ raise ValueError (f"Unknown weighting type: { self .weighting } " )
120137
def get_config(self):
    """Return the schedule's name and variance type as a plain dictionary."""
    return {"name": self.name, "variance_type": self.variance_type}
@@ -154,7 +171,9 @@ class LinearNoiseSchedule(NoiseSchedule):
154171 """
155172
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
    """Create a variance-preserving linear schedule that uses likelihood weighting.

    Parameters
    ----------
    min_log_snr : float, optional
        Lower bound of the log signal-to-noise ratio. Default is -15.
    max_log_snr : float, optional
        Upper bound of the log signal-to-noise ratio. Default is 15.
    """
    super().__init__(
        weighting="likelihood_weighting",
        variance_type=VarianceType.PRESERVING,
        name="linear_noise_schedule",
    )
    # replace the placeholder bounds written by the base constructor
    self._log_snr_max = max_log_snr
    self._log_snr_min = min_log_snr
160179
@@ -190,14 +209,6 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
190209 factor = ops .exp (- log_snr_t ) / (1 + ops .exp (- log_snr_t ))
191210 return - factor * dsnr_dt
192211
193- def get_weights_for_snr (self , log_snr_t : Tensor ) -> Tensor :
194- """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
195- Default is the likelihood weighting based on Song et al. (2021).
196- """
197- g_squared = self .get_drift_diffusion (log_snr_t = log_snr_t )
198- sigma_t = self .get_alpha_sigma (log_snr_t = log_snr_t , training = True )[1 ]
199- return g_squared / ops .square (sigma_t )
200-
def get_config(self):
    """Return the log-SNR bounds under the keyword names the constructor accepts."""
    return {"min_log_snr": self._log_snr_min, "max_log_snr": self._log_snr_max}
203214
@@ -214,8 +225,10 @@ class CosineNoiseSchedule(NoiseSchedule):
214225 [1] Diffusion models beat gans on image synthesis: Dhariwal and Nichol (2022)
215226 """
216227
def __init__(
    self,
    min_log_snr: float = -15,
    max_log_snr: float = 15,
    s_shift_cosine: float = 0.0,
    weighting: str = "sigmoid",
):
    """Create a variance-preserving cosine schedule.

    Parameters
    ----------
    min_log_snr : float, optional
        Lower bound of the log signal-to-noise ratio. Default is -15.
    max_log_snr : float, optional
        Upper bound of the log signal-to-noise ratio. Default is 15.
    s_shift_cosine : float, optional
        Shift parameter of the cosine schedule. Default is 0.0.
    weighting : str, optional
        Name of the loss weighting forwarded to the base class. Default is "sigmoid".
    """
    super().__init__(weighting=weighting, variance_type=VarianceType.PRESERVING, name="cosine_noise_schedule")
    # replace the placeholder bounds written by the base constructor
    self._log_snr_max = max_log_snr
    self._log_snr_min = min_log_snr
    self._s_shift_cosine = s_shift_cosine
@@ -252,12 +265,6 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
252265 factor = ops .exp (- log_snr_t ) / (1 + ops .exp (- log_snr_t ))
253266 return - factor * dsnr_dt
254267
255- def get_weights_for_snr (self , log_snr_t : Tensor ) -> Tensor :
256- """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
257- Default is the sigmoid weighting based on Kingma et al. (2023).
258- """
259- return ops .sigmoid (- log_snr_t + 2 )
260-
def get_config(self):
    """Return every constructor argument so the schedule can be rebuilt from its config.

    ``weighting`` is included because the constructor accepts it; leaving it out
    would make a schedule created with a non-default weighting fall back to the
    constructor default after a save/load round trip.
    """
    return dict(
        min_log_snr=self._log_snr_min,
        max_log_snr=self._log_snr_max,
        s_shift_cosine=self._s_shift_cosine,
        weighting=self.weighting,
    )
263270
@@ -345,6 +352,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
345352
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
    """Return the loss weights for the given log signal-to-noise ratio (lambda).

    Computes ``exp(-log_snr_t) / sigma_data**2 + 1`` element-wise.
    """
    # for F-prediction: w = (ops.exp(-log_snr_t) + sigma_data^2) / (ops.exp(-log_snr_t)*sigma_data^2)
    inverse_snr = ops.exp(-log_snr_t)
    sigma_data_squared = ops.square(self.sigma_data)
    return inverse_snr / sigma_data_squared + 1
349357
350358 def get_config (self ):
@@ -384,7 +392,7 @@ def __init__(
384392 integrate_kwargs : dict [str , any ] = None ,
385393 subnet_kwargs : dict [str , any ] = None ,
386394 noise_schedule : str | NoiseSchedule = "cosine" ,
387- prediction_type : str = "velocity" ,
395+ prediction_type : PredictionType = "velocity" ,
388396 ** kwargs ,
389397 ):
390398 """
@@ -431,13 +439,21 @@ def __init__(
431439 # validate noise model
432440 self .noise_schedule .validate ()
433441
# Accept enum members as well as their plain string values (the signature default is
# the string "velocity"). PredictionType(...) raises ValueError for unknown values,
# which is re-raised with the message used elsewhere in this check.
try:
    prediction_type = PredictionType(prediction_type)
except ValueError:
    raise ValueError(f"Unknown prediction type: {prediction_type}") from None
# Only these parameterizations are accepted; the check must be `not in`, otherwise
# every supported type is rejected and every unsupported one is let through.
if prediction_type not in [PredictionType.NOISE, PredictionType.VELOCITY, PredictionType.F]:  # F is EDM
    raise ValueError(f"Unknown prediction type: {prediction_type}")
self._prediction_type = prediction_type
# Use the schedule stored on the instance (already validated above): the raw
# `noise_schedule` argument may be a plain string, which has no `.name`.
if self.noise_schedule.name == "edm_noise_schedule" and prediction_type != PredictionType.F:
    warnings.warn(
        "EDM noise schedule is build for F-prediction. Consider using F-prediction instead.",
    )
# NOTE(review): `loss_type` is read with kwargs.get, so it stays in kwargs — confirm
# nothing downstream forwards kwargs to a constructor that rejects it.
loss_type = kwargs.get("loss_type", PredictionType.NOISE)
try:
    loss_type = PredictionType(loss_type)
except ValueError:
    raise ValueError(f"Unknown loss type: {loss_type}") from None
if loss_type not in [PredictionType.NOISE, PredictionType.VELOCITY, PredictionType.F]:
    raise ValueError(f"Unknown loss type: {loss_type}")
self._loss_type = loss_type
if self._loss_type != PredictionType.NOISE:
    warnings.warn(
        "the standard schedules have weighting functions defined for the noise prediction loss. "
        "You might want to replace them, if you use a different loss function."
    )
441457
442458 # clipping of prediction (after it was transformed to x-prediction)
443459 self ._clip_min = - 5.0
@@ -489,7 +505,8 @@ def get_config(self):
489505 "subnet" : self .subnet ,
490506 "noise_schedule" : self .noise_schedule ,
491507 "integrate_kwargs" : self .integrate_kwargs ,
492- "prediction_type" : self .prediction_type ,
508+ "prediction_type" : self ._prediction_type ,
509+ "loss_type" : self ._loss_type ,
493510 }
494511 return base_config | serialize (config )
495512
@@ -501,18 +518,18 @@ def convert_prediction_to_x(
501518 self , pred : Tensor , z : Tensor , alpha_t : Tensor , sigma_t : Tensor , log_snr_t : Tensor , clip_x : bool
502519 ) -> Tensor :
503520 """Convert the prediction of the neural network to the x space."""
504- if self .prediction_type == "velocity" :
521+ if self ._prediction_type == PredictionType . VELOCITY :
505522 # convert v into x
506523 x = alpha_t * z - sigma_t * pred
507- elif self .prediction_type == "noise" :
524+ elif self ._prediction_type == PredictionType . NOISE :
508525 # convert noise prediction into x
509526 x = (z - sigma_t * pred ) / alpha_t
510- elif self .prediction_type == "F" : # EDM
527+ elif self ._prediction_type == PredictionType . F : # EDM
511528 sigma_data = self .noise_schedule .sigma_data
512529 x1 = (sigma_data ** 2 * alpha_t ) / (ops .exp (- log_snr_t ) + sigma_data ** 2 )
513530 x2 = ops .exp (- log_snr_t / 2 ) * sigma_data / ops .sqrt (ops .exp (- log_snr_t ) + sigma_data ** 2 )
514531 x = x1 * z + x2 * pred
515- elif self .prediction_type == "x" :
532+ elif self ._prediction_type == PredictionType . X :
516533 x = pred
517534 else : # "score"
518535 x = (z + sigma_t ** 2 * pred ) / alpha_t
@@ -757,10 +774,26 @@ def compute_metrics(
757774 pred = pred , z = diffused_x , alpha_t = alpha_t , sigma_t = sigma_t , log_snr_t = log_snr_t , clip_x = False
758775 )
759776
760- # convert x to epsilon prediction
761- noise_pred = (diffused_x - alpha_t * x_pred ) / sigma_t
762777 # Calculate loss
763- loss = weights_for_snr * ops .mean ((noise_pred - eps_t ) ** 2 , axis = - 1 )
778+ if self ._loss_type == PredictionType .NOISE :
779+ # convert x to epsilon prediction
780+ noise_pred = (diffused_x - alpha_t * x_pred ) / sigma_t
781+ loss = weights_for_snr * ops .mean ((noise_pred - eps_t ) ** 2 , axis = - 1 )
782+ elif self ._loss_type == PredictionType .VELOCITY :
783+ # convert x to velocity prediction
784+ velocity_pred = (alpha_t * diffused_x - x_pred ) / sigma_t
785+ v_t = alpha_t * eps_t - sigma_t * x
786+ loss = weights_for_snr * ops .mean ((velocity_pred - v_t ) ** 2 , axis = - 1 )
787+ elif self ._loss_type == PredictionType .F :
788+ # convert x to F prediction
789+ sigma_data = self .noise_schedule .sigma_data
790+ x1 = ops .sqrt (ops .exp (- log_snr_t ) + sigma_data ** 2 ) / (ops .exp (- log_snr_t / 2 ) * sigma_data )
791+ x2 = (sigma_data * alpha_t ) / (ops .exp (- log_snr_t / 2 ) * ops .sqrt (ops .exp (- log_snr_t ) + sigma_data ** 2 ))
792+ f_pred = x1 * x_pred - x2 * diffused_x
793+ f_t = x1 * x - x2 * diffused_x
794+ loss = weights_for_snr * ops .mean ((f_pred - f_t ) ** 2 , axis = - 1 )
795+ else :
796+ raise ValueError (f"Unknown loss type: { self ._loss_type } " )
764797
765798 # apply sample weight
766799 loss = weighted_mean (loss , sample_weight )
0 commit comments