44import keras
55from keras import ops
66import warnings
7- from enum import Enum
87
98from bayesflow .utils .serialization import serialize , deserialize , serializable
109from bayesflow .types import Tensor , Shape
2221)
2322
2423
25- class VarianceType (Enum ):
26- PRESERVING = "preserving"
27- EXPLODING = "exploding"
28-
29-
30- class PredictionType (Enum ):
31- VELOCITY = "velocity"
32- NOISE = "noise"
33- X = "x"
34- F = "F" # EDM
35- SCORE = "score"
36-
37-
3824@serializable
3925class NoiseSchedule (ABC ):
4026 r"""Noise schedule for diffusion models. We follow the notation from [1].
@@ -53,7 +39,7 @@ class NoiseSchedule(ABC):
5339 Augmentation: Kingma et al. (2023)
5440 """
5541
56- def __init__ (self , name : str , variance_type : VarianceType , weighting : str = None ):
42+ def __init__ (self , name : str , variance_type : str , weighting : str = None ):
5743 self .name = name
5844 self .variance_type = variance_type # 'exploding' or 'preserving'
5945 self ._log_snr_min = - 15 # should be set in the subclasses
@@ -90,9 +76,9 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
9076 beta = self .derivative_log_snr (log_snr_t = log_snr_t , training = training )
9177 if x is None : # return g^2 only
9278 return beta
93- if self .variance_type == VarianceType . PRESERVING :
79+ if self .variance_type == "preserving" :
9480 f = - 0.5 * beta * x
95- elif self .variance_type == VarianceType . EXPLODING :
81+ elif self .variance_type == "exploding" :
9682 f = ops .zeros_like (beta )
9783 else :
9884 raise ValueError (f"Unknown variance type: { self .variance_type } " )
@@ -106,11 +92,11 @@ def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Te
10692 sigma(t) = sqrt(sigmoid(-log_snr_t))
10793 For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda)
10894 """
109- if self .variance_type == VarianceType . PRESERVING :
95+ if self .variance_type == "preserving" :
11096 # variance preserving schedule
11197 alpha_t = ops .sqrt (ops .sigmoid (log_snr_t ))
11298 sigma_t = ops .sqrt (ops .sigmoid (- log_snr_t ))
113- elif self .variance_type == VarianceType . EXPLODING :
99+ elif self .variance_type == "exploding" :
114100 # variance exploding schedule
115101 alpha_t = ops .ones_like (log_snr_t )
116102 sigma_t = ops .sqrt (ops .exp (- log_snr_t ))
@@ -171,9 +157,7 @@ class LinearNoiseSchedule(NoiseSchedule):
171157 """
172158
173159 def __init__ (self , min_log_snr : float = - 15 , max_log_snr : float = 15 ):
174- super ().__init__ (
175- name = "linear_noise_schedule" , variance_type = VarianceType .PRESERVING , weighting = "likelihood_weighting"
176- )
160+ super ().__init__ (name = "linear_noise_schedule" , variance_type = "preserving" , weighting = "likelihood_weighting" )
177161 self ._log_snr_min = min_log_snr
178162 self ._log_snr_max = max_log_snr
179163
@@ -228,7 +212,7 @@ class CosineNoiseSchedule(NoiseSchedule):
228212 def __init__ (
229213 self , min_log_snr : float = - 15 , max_log_snr : float = 15 , s_shift_cosine : float = 0.0 , weighting : str = "sigmoid"
230214 ):
231- super ().__init__ (name = "cosine_noise_schedule" , variance_type = VarianceType . PRESERVING , weighting = weighting )
215+ super ().__init__ (name = "cosine_noise_schedule" , variance_type = "preserving" , weighting = weighting )
232216 self ._s_shift_cosine = s_shift_cosine
233217 self ._log_snr_min = min_log_snr
234218 self ._log_snr_max = max_log_snr
@@ -283,7 +267,7 @@ class EDMNoiseSchedule(NoiseSchedule):
283267 """
284268
285269 def __init__ (self , sigma_data : float = 1.0 , sigma_min : float = 1e-4 , sigma_max : float = 80.0 ):
286- super ().__init__ (name = "edm_noise_schedule" , variance_type = VarianceType . PRESERVING )
270+ super ().__init__ (name = "edm_noise_schedule" , variance_type = "preserving" )
287271 self .sigma_data = sigma_data
288272 # training settings
289273 self .p_mean = - 1.2
@@ -392,7 +376,7 @@ def __init__(
392376 integrate_kwargs : dict [str , any ] = None ,
393377 subnet_kwargs : dict [str , any ] = None ,
394378 noise_schedule : str | NoiseSchedule = "cosine" ,
395- prediction_type : PredictionType = "velocity" ,
379+ prediction_type : str = "velocity" ,
396380 ** kwargs ,
397381 ):
398382 """
@@ -439,17 +423,17 @@ def __init__(
439423 # validate noise model
440424 self .noise_schedule .validate ()
441425
442- if prediction_type in [PredictionType . NOISE , PredictionType . VELOCITY , PredictionType . F ]: # F is EDM
426+ if prediction_type not in ["noise" , "velocity" , "F" ]: # F is EDM
443427 raise ValueError (f"Unknown prediction type: { prediction_type } " )
444428 self ._prediction_type = prediction_type
445- if noise_schedule .name == "edm_noise_schedule" and prediction_type != PredictionType . F :
429+ if noise_schedule .name == "edm_noise_schedule" and prediction_type != "F" :
446430 warnings .warn (
447431 "EDM noise schedule is build for F-prediction. Consider using F-prediction instead." ,
448432 )
449- self ._loss_type = kwargs .get ("loss_type" , PredictionType . NOISE )
450- if self ._loss_type not in [PredictionType . NOISE , PredictionType . VELOCITY , PredictionType . F ]:
433+ self ._loss_type = kwargs .get ("loss_type" , "noise" )
434+ if self ._loss_type not in ["noise" , "velocity" , "F" ]:
451435 raise ValueError (f"Unknown loss type: { self ._loss_type } " )
452- if self ._loss_type != PredictionType . NOISE :
436+ if self ._loss_type != "noise" :
453437 warnings .warn (
454438 "the standard schedules have weighting functions defined for the noise prediction loss. "
455439 "You might want to replace them, if you use a different loss function."
@@ -518,18 +502,18 @@ def convert_prediction_to_x(
518502 self , pred : Tensor , z : Tensor , alpha_t : Tensor , sigma_t : Tensor , log_snr_t : Tensor , clip_x : bool
519503 ) -> Tensor :
520504 """Convert the prediction of the neural network to the x space."""
521- if self ._prediction_type == PredictionType . VELOCITY :
505+ if self ._prediction_type == "velocity" :
522506 # convert v into x
523507 x = alpha_t * z - sigma_t * pred
524- elif self ._prediction_type == PredictionType . NOISE :
508+ elif self ._prediction_type == "noise" :
525509 # convert noise prediction into x
526510 x = (z - sigma_t * pred ) / alpha_t
527- elif self ._prediction_type == PredictionType . F : # EDM
511+ elif self ._prediction_type == "F" : # EDM
528512 sigma_data = self .noise_schedule .sigma_data
529513 x1 = (sigma_data ** 2 * alpha_t ) / (ops .exp (- log_snr_t ) + sigma_data ** 2 )
530514 x2 = ops .exp (- log_snr_t / 2 ) * sigma_data / ops .sqrt (ops .exp (- log_snr_t ) + sigma_data ** 2 )
531515 x = x1 * z + x2 * pred
532- elif self ._prediction_type == PredictionType . X :
516+ elif self ._prediction_type == "x" :
533517 x = pred
534518 else : # "score"
535519 x = (z + sigma_t ** 2 * pred ) / alpha_t
@@ -775,16 +759,16 @@ def compute_metrics(
775759 )
776760
777761 # Calculate loss
778- if self ._loss_type == PredictionType . NOISE :
762+ if self ._loss_type == "noise" :
779763 # convert x to epsilon prediction
780764 noise_pred = (diffused_x - alpha_t * x_pred ) / sigma_t
781765 loss = weights_for_snr * ops .mean ((noise_pred - eps_t ) ** 2 , axis = - 1 )
782- elif self ._loss_type == PredictionType . VELOCITY :
766+ elif self ._loss_type == "velocity" :
783767 # convert x to velocity prediction
784768 velocity_pred = (alpha_t * diffused_x - x_pred ) / sigma_t
785769 v_t = alpha_t * eps_t - sigma_t * x
786770 loss = weights_for_snr * ops .mean ((velocity_pred - v_t ) ** 2 , axis = - 1 )
787- elif self ._loss_type == PredictionType . F :
771+ elif self ._loss_type == "F" :
788772 # convert x to F prediction
789773 sigma_data = self .noise_schedule .sigma_data
790774 x1 = ops .sqrt (ops .exp (- log_snr_t ) + sigma_data ** 2 ) / (ops .exp (- log_snr_t / 2 ) * sigma_data )
0 commit comments