Skip to content

Commit e4675da

Browse files
committed
add a wrapper for beta schedule functions to enforce zero terminal snr as algorithm 1 in https://arxiv.org/abs/2305.08891, also add the rescaling of classifier free guidance as proposed in that paper
1 parent 6d15667 commit e4675da

File tree

8 files changed

+104
-84
lines changed

8 files changed

+104
-84
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ model = Unet(
3838
diffusion = GaussianDiffusion(
3939
model,
4040
image_size = 128,
41-
timesteps = 1000, # number of steps
42-
loss_type = 'l1' # L1 or L2
41+
timesteps = 1000 # number of steps
4342
)
4443

4544
training_images = torch.rand(8, 3, 128, 128) # images are normalized from 0 to 1
@@ -65,8 +64,7 @@ diffusion = GaussianDiffusion(
6564
model,
6665
image_size = 128,
6766
timesteps = 1000, # number of steps
68-
sampling_timesteps = 250, # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
69-
loss_type = 'l1' # L1 or L2
67+
sampling_timesteps = 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
7068
)
7169

7270
trainer = Trainer(
@@ -148,9 +146,11 @@ sampled_seq = diffusion.sample(batch_size = 4)
148146
sampled_seq.shape # (4, 32, 128)
149147

150148
```
149+
151150
`Trainer1D` does not evaluate the generated samples in any way since the type of data is not known.
152151
You could consider adding a suitable metric to the training loop yourself after doing an editable install of this package
153152
`pip install -e .`.
153+
154154
## Citations
155155

156156
```bibtex

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
from pathlib import Path
44
from random import random
5-
from functools import partial
5+
from functools import partial, wraps
66
from collections import namedtuple
77
from multiprocessing import cpu_count
88

@@ -375,6 +375,7 @@ def forward_with_cond_scale(
375375
self,
376376
*args,
377377
cond_scale = 1.,
378+
rescale_phi = 0.,
378379
**kwargs
379380
):
380381
logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
@@ -383,7 +384,18 @@ def forward_with_cond_scale(
383384
return logits
384385

385386
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
386-
return null_logits + (logits - null_logits) * cond_scale
387+
scaled_logits = null_logits + (logits - null_logits) * cond_scale
388+
389+
if rescale_phi <= 0:
390+
return scaled_logits
391+
392+
# rescaling proposed in https://arxiv.org/abs/2305.08891 to prevent over-saturation
393+
# works for both pixel and latent space, as opposed to only pixel space with the technique from Imagen
394+
# they found 0.7 to work well empirically with a conditional scale of 6.
395+
396+
std_fn = partial(torch.std, dim = tuple(range(1, scaled_logits.ndim)), keepdim = True)
397+
rescaled_logits = scaled_logits * (std_fn(logits) / std_fn(scaled_logits))
398+
return rescaled_logits * rescale_phi + (1 - rescale_phi) * scaled_logits
387399

388400
def forward(
389401
self,
@@ -457,6 +469,33 @@ def extract(a, t, x_shape):
457469
out = a.gather(-1, t)
458470
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
459471

472+
def enforce_zero_terminal_snr(schedule_fn):
473+
# algorithm 1 in https://arxiv.org/abs/2305.08891
474+
475+
@wraps(schedule_fn)
476+
def inner(*args, **kwargs):
477+
betas = schedule_fn(*args, **kwargs)
478+
alphas = 1. - betas
479+
480+
alphas_cumprod = torch.cumprod(alphas, dim = 0)
481+
alphas_cumprod = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
482+
483+
alphas_cumprod_sqrt = torch.sqrt(alphas_cumprod)
484+
485+
terminal_snr = alphas_cumprod_sqrt[-1].clone()
486+
487+
alphas_cumprod_sqrt -= terminal_snr # enforce zero terminal snr
488+
alphas_cumprod_sqrt *= 1. / (1. - terminal_snr)
489+
490+
alphas_cumprod = alphas_cumprod_sqrt ** 2
491+
alphas = alphas_cumprod[1:] / alphas_cumprod[:-1]
492+
betas = 1. - alphas
493+
494+
return betas
495+
496+
return inner
497+
498+
@enforce_zero_terminal_snr
460499
def linear_beta_schedule(timesteps):
461500
scale = 1000 / timesteps
462501
beta_start = scale * 0.0001
@@ -473,7 +512,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
473512
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
474513
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
475514
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
476-
return torch.clip(betas, 0, 0.999)
515+
return torch.clip(betas, 0, 1.)
477516

478517
class GaussianDiffusion(nn.Module):
479518
def __init__(
@@ -606,8 +645,8 @@ def q_posterior(self, x_start, x_t, t):
606645
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
607646
return posterior_mean, posterior_variance, posterior_log_variance_clipped
608647

609-
def model_predictions(self, x, t, classes, cond_scale = 3., clip_x_start = False):
610-
model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale)
648+
def model_predictions(self, x, t, classes, cond_scale = 6., rescale_phi = 0.7, clip_x_start = False):
649+
model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescale_phi = rescale_phi)
611650
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
612651

613652
if self.objective == 'pred_noise':
@@ -639,7 +678,7 @@ def p_mean_variance(self, x, t, classes, cond_scale, clip_denoised = True):
639678
return model_mean, posterior_variance, posterior_log_variance, x_start
640679

641680
@torch.no_grad()
642-
def p_sample(self, x, t: int, classes, cond_scale = 3., clip_denoised = True):
681+
def p_sample(self, x, t: int, classes, cond_scale = 6., rescale_phi = 0.7, clip_denoised = True):
643682
b, *_, device = *x.shape, x.device
644683
batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
645684
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, classes = classes, cond_scale = cond_scale, clip_denoised = clip_denoised)
@@ -648,7 +687,7 @@ def p_sample(self, x, t: int, classes, cond_scale = 3., clip_denoised = True):
648687
return pred_img, x_start
649688

650689
@torch.no_grad()
651-
def p_sample_loop(self, classes, shape, cond_scale = 3.):
690+
def p_sample_loop(self, classes, shape, cond_scale = 6., rescale_phi = 0.7):
652691
batch, device = shape[0], self.betas.device
653692

654693
img = torch.randn(shape, device=device)
@@ -662,7 +701,7 @@ def p_sample_loop(self, classes, shape, cond_scale = 3.):
662701
return img
663702

664703
@torch.no_grad()
665-
def ddim_sample(self, classes, shape, cond_scale = 3., clip_denoised = True):
704+
def ddim_sample(self, classes, shape, cond_scale = 6., rescale_phi = 0.7, clip_denoised = True):
666705
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
667706

668707
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -726,15 +765,6 @@ def q_sample(self, x_start, t, noise=None):
726765
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
727766
)
728767

729-
@property
730-
def loss_fn(self):
731-
if self.loss_type == 'l1':
732-
return F.l1_loss
733-
elif self.loss_type == 'l2':
734-
return F.mse_loss
735-
else:
736-
raise ValueError(f'invalid loss type {self.loss_type}')
737-
738768
def p_losses(self, x_start, t, *, classes, noise = None):
739769
b, c, h, w = x_start.shape
740770
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -757,7 +787,7 @@ def p_losses(self, x_start, t, *, classes, noise = None):
757787
else:
758788
raise ValueError(f'unknown objective {self.objective}')
759789

760-
loss = self.loss_fn(model_out, target, reduction = 'none')
790+
loss = F.mse_loss(model_out, target, reduction = 'none')
761791
loss = reduce(loss, 'b ... -> b (...)', 'mean')
762792

763793
loss = loss * extract(self.loss_weight, t, loss.shape)

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def __init__(
116116
*,
117117
image_size,
118118
channels = 3,
119-
loss_type = 'l1',
120119
noise_schedule = 'linear',
121120
num_sample_steps = 500,
122121
clip_sample_denoised = True,
@@ -138,8 +137,6 @@ def __init__(
138137

139138
# continuous noise schedule related stuff
140139

141-
self.loss_type = loss_type
142-
143140
if noise_schedule == 'linear':
144141
self.log_snr = beta_linear_log_snr
145142
elif noise_schedule == 'cosine':
@@ -170,15 +167,6 @@ def __init__(
170167
def device(self):
171168
return next(self.model.parameters()).device
172169

173-
@property
174-
def loss_fn(self):
175-
if self.loss_type == 'l1':
176-
return F.l1_loss
177-
elif self.loss_type == 'l2':
178-
return F.mse_loss
179-
else:
180-
raise ValueError(f'invalid loss type {self.loss_type}')
181-
182170
def p_mean_variance(self, x, time, time_next):
183171
# reviewer found an error in the equation in the paper (missing sigma)
184172
# following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt
@@ -266,7 +254,7 @@ def p_losses(self, x_start, times, noise = None):
266254
x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
267255
model_out = self.model(x, log_snr)
268256

269-
losses = self.loss_fn(model_out, noise, reduction = 'none')
257+
losses = F.mse_loss(model_out, noise, reduction = 'none')
270258
losses = reduce(losses, 'b ... -> b', 'mean')
271259

272260
if self.min_snr_loss_weight:

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
from pathlib import Path
44
from random import random
5-
from functools import partial
5+
from functools import partial, wraps
66
from collections import namedtuple
77
from multiprocessing import cpu_count
88

@@ -68,6 +68,11 @@ def convert_image_to_fn(img_type, image):
6868
return image.convert(img_type)
6969
return image
7070

71+
def extract(a, t, x_shape):
72+
b, *_ = t.shape
73+
out = a.gather(-1, t)
74+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
75+
7176
# normalization functions
7277

7378
def normalize_to_neg_one_to_one(img):
@@ -398,13 +403,35 @@ def forward(self, x, time, x_self_cond = None):
398403
x = self.final_res_block(x, t)
399404
return self.final_conv(x)
400405

401-
# gaussian diffusion trainer class
406+
# scheduling functions
402407

403-
def extract(a, t, x_shape):
404-
b, *_ = t.shape
405-
out = a.gather(-1, t)
406-
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
408+
def enforce_zero_terminal_snr(schedule_fn):
409+
# algorithm 1 in https://arxiv.org/abs/2305.08891
410+
411+
@wraps(schedule_fn)
412+
def inner(*args, **kwargs):
413+
betas = schedule_fn(*args, **kwargs)
414+
alphas = 1. - betas
415+
416+
alphas_cumprod = torch.cumprod(alphas, dim = 0)
417+
alphas_cumprod = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
407418

419+
alphas_cumprod_sqrt = torch.sqrt(alphas_cumprod)
420+
421+
terminal_snr = alphas_cumprod_sqrt[-1].clone()
422+
423+
alphas_cumprod_sqrt -= terminal_snr # enforce zero terminal snr
424+
alphas_cumprod_sqrt *= 1. / (1. - terminal_snr)
425+
426+
alphas_cumprod = alphas_cumprod_sqrt ** 2
427+
alphas = alphas_cumprod[1:] / alphas_cumprod[:-1]
428+
betas = 1. - alphas
429+
430+
return betas
431+
432+
return inner
433+
434+
@enforce_zero_terminal_snr
408435
def linear_beta_schedule(timesteps):
409436
"""
410437
linear schedule, proposed in original ddpm paper
@@ -426,6 +453,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
426453
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
427454
return torch.clip(betas, 0, 1.)
428455

456+
@enforce_zero_terminal_snr
429457
def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
430458
"""
431459
sigmoid schedule
@@ -441,6 +469,8 @@ def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1
441469
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
442470
return torch.clip(betas, 0, 1.)
443471

472+
# gaussian diffusion trainer class
473+
444474
class GaussianDiffusion(nn.Module):
445475
def __init__(
446476
self,
@@ -449,7 +479,6 @@ def __init__(
449479
image_size,
450480
timesteps = 1000,
451481
sampling_timesteps = None,
452-
loss_type = 'l1',
453482
objective = 'pred_noise',
454483
beta_schedule = 'sigmoid',
455484
schedule_fn_kwargs = dict(),
@@ -473,7 +502,9 @@ def __init__(
473502

474503
assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
475504

476-
if beta_schedule == 'linear':
505+
if callable(beta_schedule):
506+
beta_schedule_fn = beta_schedule
507+
elif beta_schedule == 'linear':
477508
beta_schedule_fn = linear_beta_schedule
478509
elif beta_schedule == 'cosine':
479510
beta_schedule_fn = cosine_beta_schedule
@@ -485,12 +516,11 @@ def __init__(
485516
betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
486517

487518
alphas = 1. - betas
488-
alphas_cumprod = torch.cumprod(alphas, dim=0)
519+
alphas_cumprod = torch.cumprod(alphas, dim = 0)
489520
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
490521

491522
timesteps, = betas.shape
492523
self.num_timesteps = int(timesteps)
493-
self.loss_type = loss_type
494524

495525
# sampling related parameters
496526

@@ -511,6 +541,10 @@ def __init__(
511541
# calculations for diffusion q(x_t | x_{t-1}) and others
512542

513543
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
544+
545+
terminal_snr = self.sqrt_alphas_cumprod[-1]
546+
assert terminal_snr == 0, f'non-zero terminal SNR detected ({terminal_snr:.6f}), from https://arxiv.org/abs/2305.08891 paper - you can wrap your schedule function with `enforce_zero_terminal_snr` decorator'
547+
514548
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
515549
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
516550
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
@@ -725,15 +759,6 @@ def q_sample(self, x_start, t, noise=None):
725759
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
726760
)
727761

728-
@property
729-
def loss_fn(self):
730-
if self.loss_type == 'l1':
731-
return F.l1_loss
732-
elif self.loss_type == 'l2':
733-
return F.mse_loss
734-
else:
735-
raise ValueError(f'invalid loss type {self.loss_type}')
736-
737762
def p_losses(self, x_start, t, noise = None):
738763
b, c, h, w = x_start.shape
739764
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -766,7 +791,7 @@ def p_losses(self, x_start, t, noise = None):
766791
else:
767792
raise ValueError(f'unknown objective {self.objective}')
768793

769-
loss = self.loss_fn(model_out, target, reduction = 'none')
794+
loss = F.mse_loss(model_out, target, reduction = 'none')
770795
loss = reduce(loss, 'b ... -> b (...)', 'mean')
771796

772797
loss = loss * extract(self.loss_weight, t, loss.shape)

0 commit comments

Comments
 (0)