
Commit 81b4a5d (parent: 0dee2d8)

fixes for learned gaussian diffusion

2 files changed: +9 lines, -5 lines

denoising_diffusion_pytorch/learned_gaussian_diffusion.py

Lines changed: 8 additions & 4 deletions
@@ -81,10 +81,12 @@ def __init__(

         self.vb_loss_weight = vb_loss_weight

-    def model_predictions(self, x, t):
+    def model_predictions(self, x, t, clip_x_start = False):
         model_output = self.model(x, t)
         model_output, pred_variance = model_output.chunk(2, dim = 1)

+        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
+
         if self.objective == 'pred_noise':
             pred_noise = model_output
             x_start = self.predict_start_from_noise(x, t, model_output)
@@ -93,9 +95,11 @@ def model_predictions(self, x, t):
             pred_noise = self.predict_noise_from_start(x, t, model_output)
             x_start = model_output

+        x_start = maybe_clip(x_start)
+
         return ModelPrediction(pred_noise, x_start, pred_variance)

-    def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
+    def p_mean_variance(self, *, x, t, clip_denoised, model_output = None, **kwargs):
         model_output = default(model_output, lambda: self.model(x, t))
         pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)

@@ -113,7 +117,7 @@ def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):

         model_mean, _, _ = self.q_posterior(x_start, x, t)

-        return model_mean, model_variance, model_log_variance
+        return model_mean, model_variance, model_log_variance, x_start

     def p_losses(self, x_start, t, noise = None, clip_denoised = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -126,7 +130,7 @@ def p_losses(self, x_start, t, noise = None, clip_denoised = False):
         # calculating kl loss for learned variance (interpolation)

         true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t)
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)
+        model_mean, _, model_log_variance, _ = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)

         # kl loss with detached model predicted mean, for stability reasons as in paper
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.5.5'
+__version__ = '1.5.6'

0 commit comments
