@@ -81,10 +81,12 @@ def __init__(
8181
8282 self .vb_loss_weight = vb_loss_weight
8383
84- def model_predictions (self , x , t ):
84+ def model_predictions (self , x , t , clip_x_start = False ):
8585 model_output = self .model (x , t )
8686 model_output , pred_variance = model_output .chunk (2 , dim = 1 )
8787
88+ maybe_clip = partial (torch .clamp , min = - 1. , max = 1. ) if clip_x_start else identity
89+
8890 if self .objective == 'pred_noise' :
8991 pred_noise = model_output
9092 x_start = self .predict_start_from_noise (x , t , model_output )
@@ -93,9 +95,11 @@ def model_predictions(self, x, t):
9395 pred_noise = self .predict_noise_from_start (x , t , model_output )
9496 x_start = model_output
9597
98+ x_start = maybe_clip (x_start )
99+
96100 return ModelPrediction (pred_noise , x_start , pred_variance )
97101
98- def p_mean_variance (self , * , x , t , clip_denoised , model_output = None ):
102+ def p_mean_variance (self , * , x , t , clip_denoised , model_output = None , ** kwargs ):
99103 model_output = default (model_output , lambda : self .model (x , t ))
100104 pred_noise , var_interp_frac_unnormalized = model_output .chunk (2 , dim = 1 )
101105
@@ -113,7 +117,7 @@ def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
113117
114118 model_mean , _ , _ = self .q_posterior (x_start , x , t )
115119
116- return model_mean , model_variance , model_log_variance
120+ return model_mean , model_variance , model_log_variance , x_start
117121
118122 def p_losses (self , x_start , t , noise = None , clip_denoised = False ):
119123 noise = default (noise , lambda : torch .randn_like (x_start ))
@@ -126,7 +130,7 @@ def p_losses(self, x_start, t, noise = None, clip_denoised = False):
126130 # calculating kl loss for learned variance (interpolation)
127131
128132 true_mean , _ , true_log_variance_clipped = self .q_posterior (x_start = x_start , x_t = x_t , t = t )
129- model_mean , _ , model_log_variance = self .p_mean_variance (x = x_t , t = t , clip_denoised = clip_denoised , model_output = model_output )
133+ model_mean , _ , model_log_variance , _ = self .p_mean_variance (x = x_t , t = t , clip_denoised = clip_denoised , model_output = model_output )
130134
131135 # kl loss with detached model predicted mean, for stability reasons as in paper
132136
0 commit comments