@@ -44,6 +44,7 @@ def __init__(self, name: str, variance_type: str):
4444 self .variance_type = variance_type # 'exploding' or 'preserving'
4545 self ._log_snr_min = - 15 # should be set in the subclasses
4646 self ._log_snr_max = 15 # should be set in the subclasses
47+ self .sigma_data = 1.0
4748
4849 @property
4950 def scale_base_distribution (self ):
@@ -381,7 +382,7 @@ def __init__(
381382 integrate_kwargs : dict [str , any ] = None ,
382383 subnet_kwargs : dict [str , any ] = None ,
383384 noise_schedule : str | NoiseSchedule = "cosine" ,
384- prediction_type : str = "v " ,
385+ prediction_type : str = "velocity " ,
385386 ** kwargs ,
386387 ):
387388 """
@@ -406,7 +407,8 @@ def __init__(
406407 The noise schedule used for the diffusion process. Can be "linear", "cosine", or "edm".
407408 Default is "cosine".
408409 prediction_type: str, optional
409- The type of prediction used in the diffusion model. Can be "eps", "v" or "F" (EDM). Default is "v".
410+ The type of prediction used in the diffusion model. Can be "velocity", "noise" or "F" (EDM).
411+ Default is "velocity".
410412 **kwargs
411413 Additional keyword arguments passed to the subnet and other components.
412414 """
@@ -427,7 +429,7 @@ def __init__(
427429 # validate noise model
428430 self .noise_schedule .validate ()
429431
430- if prediction_type not in ["eps " , "v " , "F" ]: # F is EDM
432+ if prediction_type not in ["velocity " , "noise " , "F" ]: # F is EDM
431433 raise ValueError (f"Unknown prediction type: { prediction_type } " )
432434 self .prediction_type = prediction_type
433435
@@ -496,10 +498,10 @@ def convert_prediction_to_x(
496498 self , pred : Tensor , z : Tensor , alpha_t : Tensor , sigma_t : Tensor , log_snr_t : Tensor , clip_x : bool
497499 ) -> Tensor :
498500 """Convert the prediction of the neural network to the x space."""
499- if self .prediction_type == "v " :
501+ if self .prediction_type == "velocity " :
500502 # convert v into x
501503 x = alpha_t * z - sigma_t * pred
502- elif self .prediction_type == "eps " :
504+ elif self .prediction_type == "noise " :
503505 # convert noise prediction into x
504506 x = (z - sigma_t * pred ) / alpha_t
505507 elif self .prediction_type == "x" :
@@ -700,11 +702,11 @@ def compute_metrics(
700702 pred = pred , z = diffused_x , alpha_t = alpha_t , sigma_t = sigma_t , log_snr_t = log_snr_t , clip_x = False
701703 )
702704 # convert x to epsilon prediction
703- out = (alpha_t * diffused_x - x_pred ) / sigma_t
705+ noise_pred = (alpha_t * diffused_x - x_pred ) / sigma_t
704706
705707 # Calculate loss based on noise prediction
706708 weights_for_snr = self .noise_schedule .get_weights_for_snr (log_snr_t = log_snr_t )
707- loss = weights_for_snr * ops .mean ((out - eps_t ) ** 2 , axis = - 1 )
709+ loss = weights_for_snr * ops .mean ((noise_pred - eps_t ) ** 2 , axis = - 1 )
708710
709711 # apply sample weight
710712 loss = weighted_mean (loss , sample_weight )
0 commit comments