Skip to content

Commit 77e8405

Browse files
committed
Revert "add a wrapper for beta schedule functions to enforce zero terminal snr as algorithm 1 in https://arxiv.org/abs/2305.08891, also add the rescaling of classifier free guidance as proposed in that paper"
This reverts commit e4675da.
1 parent e4675da commit 77e8405

File tree

8 files changed

+84
-104
lines changed

8 files changed

+84
-104
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ model = Unet(
3838
diffusion = GaussianDiffusion(
3939
model,
4040
image_size = 128,
41-
timesteps = 1000 # number of steps
41+
timesteps = 1000, # number of steps
42+
loss_type = 'l1' # L1 or L2
4243
)
4344

4445
training_images = torch.rand(8, 3, 128, 128) # images are normalized from 0 to 1
@@ -64,7 +65,8 @@ diffusion = GaussianDiffusion(
6465
model,
6566
image_size = 128,
6667
timesteps = 1000, # number of steps
67-
sampling_timesteps = 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
68+
sampling_timesteps = 250, # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
69+
loss_type = 'l1' # L1 or L2
6870
)
6971

7072
trainer = Trainer(
@@ -146,11 +148,9 @@ sampled_seq = diffusion.sample(batch_size = 4)
146148
sampled_seq.shape # (4, 32, 128)
147149

148150
```
149-
150151
`Trainer1D` does not evaluate the generated samples in any way since the type of data is not known.
151152
You could consider adding a suitable metric to the training loop yourself after doing an editable install of this package
152153
`pip install -e .`.
153-
154154
## Citations
155155

156156
```bibtex

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 18 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
from pathlib import Path
44
from random import random
5-
from functools import partial, wraps
5+
from functools import partial
66
from collections import namedtuple
77
from multiprocessing import cpu_count
88

@@ -375,7 +375,6 @@ def forward_with_cond_scale(
375375
self,
376376
*args,
377377
cond_scale = 1.,
378-
rescale_phi = 0.,
379378
**kwargs
380379
):
381380
logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
@@ -384,18 +383,7 @@ def forward_with_cond_scale(
384383
return logits
385384

386385
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
387-
scaled_logits = null_logits + (logits - null_logits) * cond_scale
388-
389-
if rescale_phi <= 0:
390-
return scaled_logits
391-
392-
# rescaling proposed in https://arxiv.org/abs/2305.08891 to prevent over-saturation
393-
# works for both pixel and latent space, as opposed to only pixel space with the technique from Imagen
394-
# they found 0.7 to work well empirically with a conditional scale of 6.
395-
396-
std_fn = partial(torch.std, dim = tuple(range(1, scaled_logits.ndim)), keepdim = True)
397-
rescaled_logits = scaled_logits * (std_fn(logits) / std_fn(scaled_logits))
398-
return rescaled_logits * rescale_phi + (1 - rescale_phi) * scaled_logits
386+
return null_logits + (logits - null_logits) * cond_scale
399387

400388
def forward(
401389
self,
@@ -469,33 +457,6 @@ def extract(a, t, x_shape):
469457
out = a.gather(-1, t)
470458
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
471459

472-
def enforce_zero_terminal_snr(schedule_fn):
473-
# algorithm 1 in https://arxiv.org/abs/2305.08891
474-
475-
@wraps(schedule_fn)
476-
def inner(*args, **kwargs):
477-
betas = schedule_fn(*args, **kwargs)
478-
alphas = 1. - betas
479-
480-
alphas_cumprod = torch.cumprod(alphas, dim = 0)
481-
alphas_cumprod = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
482-
483-
alphas_cumprod_sqrt = torch.sqrt(alphas_cumprod)
484-
485-
terminal_snr = alphas_cumprod_sqrt[-1].clone()
486-
487-
alphas_cumprod_sqrt -= terminal_snr # enforce zero terminal snr
488-
alphas_cumprod_sqrt *= 1. / (1. - terminal_snr)
489-
490-
alphas_cumprod = alphas_cumprod_sqrt ** 2
491-
alphas = alphas_cumprod[1:] / alphas_cumprod[:-1]
492-
betas = 1. - alphas
493-
494-
return betas
495-
496-
return inner
497-
498-
@enforce_zero_terminal_snr
499460
def linear_beta_schedule(timesteps):
500461
scale = 1000 / timesteps
501462
beta_start = scale * 0.0001
@@ -512,7 +473,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
512473
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
513474
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
514475
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
515-
return torch.clip(betas, 0, 1.)
476+
return torch.clip(betas, 0, 0.999)
516477

517478
class GaussianDiffusion(nn.Module):
518479
def __init__(
@@ -645,8 +606,8 @@ def q_posterior(self, x_start, x_t, t):
645606
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
646607
return posterior_mean, posterior_variance, posterior_log_variance_clipped
647608

648-
def model_predictions(self, x, t, classes, cond_scale = 6., rescale_phi = 0.7, clip_x_start = False):
649-
model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescale_phi = rescale_phi)
609+
def model_predictions(self, x, t, classes, cond_scale = 3., clip_x_start = False):
610+
model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale)
650611
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
651612

652613
if self.objective == 'pred_noise':
@@ -678,7 +639,7 @@ def p_mean_variance(self, x, t, classes, cond_scale, clip_denoised = True):
678639
return model_mean, posterior_variance, posterior_log_variance, x_start
679640

680641
@torch.no_grad()
681-
def p_sample(self, x, t: int, classes, cond_scale = 6., rescale_phi = 0.7, clip_denoised = True):
642+
def p_sample(self, x, t: int, classes, cond_scale = 3., clip_denoised = True):
682643
b, *_, device = *x.shape, x.device
683644
batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
684645
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, classes = classes, cond_scale = cond_scale, clip_denoised = clip_denoised)
@@ -687,7 +648,7 @@ def p_sample(self, x, t: int, classes, cond_scale = 6., rescale_phi = 0.7, clip_
687648
return pred_img, x_start
688649

689650
@torch.no_grad()
690-
def p_sample_loop(self, classes, shape, cond_scale = 6., rescale_phi = 0.7):
651+
def p_sample_loop(self, classes, shape, cond_scale = 3.):
691652
batch, device = shape[0], self.betas.device
692653

693654
img = torch.randn(shape, device=device)
@@ -701,7 +662,7 @@ def p_sample_loop(self, classes, shape, cond_scale = 6., rescale_phi = 0.7):
701662
return img
702663

703664
@torch.no_grad()
704-
def ddim_sample(self, classes, shape, cond_scale = 6., rescale_phi = 0.7, clip_denoised = True):
665+
def ddim_sample(self, classes, shape, cond_scale = 3., clip_denoised = True):
705666
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
706667

707668
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -765,6 +726,15 @@ def q_sample(self, x_start, t, noise=None):
765726
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
766727
)
767728

729+
@property
730+
def loss_fn(self):
731+
if self.loss_type == 'l1':
732+
return F.l1_loss
733+
elif self.loss_type == 'l2':
734+
return F.mse_loss
735+
else:
736+
raise ValueError(f'invalid loss type {self.loss_type}')
737+
768738
def p_losses(self, x_start, t, *, classes, noise = None):
769739
b, c, h, w = x_start.shape
770740
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -787,7 +757,7 @@ def p_losses(self, x_start, t, *, classes, noise = None):
787757
else:
788758
raise ValueError(f'unknown objective {self.objective}')
789759

790-
loss = F.mse_loss(model_out, target, reduction = 'none')
760+
loss = self.loss_fn(model_out, target, reduction = 'none')
791761
loss = reduce(loss, 'b ... -> b (...)', 'mean')
792762

793763
loss = loss * extract(self.loss_weight, t, loss.shape)

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
*,
117117
image_size,
118118
channels = 3,
119+
loss_type = 'l1',
119120
noise_schedule = 'linear',
120121
num_sample_steps = 500,
121122
clip_sample_denoised = True,
@@ -137,6 +138,8 @@ def __init__(
137138

138139
# continuous noise schedule related stuff
139140

141+
self.loss_type = loss_type
142+
140143
if noise_schedule == 'linear':
141144
self.log_snr = beta_linear_log_snr
142145
elif noise_schedule == 'cosine':
@@ -167,6 +170,15 @@ def __init__(
167170
def device(self):
168171
return next(self.model.parameters()).device
169172

173+
@property
174+
def loss_fn(self):
175+
if self.loss_type == 'l1':
176+
return F.l1_loss
177+
elif self.loss_type == 'l2':
178+
return F.mse_loss
179+
else:
180+
raise ValueError(f'invalid loss type {self.loss_type}')
181+
170182
def p_mean_variance(self, x, time, time_next):
171183
# reviewer found an error in the equation in the paper (missing sigma)
172184
# following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt
@@ -254,7 +266,7 @@ def p_losses(self, x_start, times, noise = None):
254266
x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
255267
model_out = self.model(x, log_snr)
256268

257-
losses = F.mse_loss(model_out, noise, reduction = 'none')
269+
losses = self.loss_fn(model_out, noise, reduction = 'none')
258270
losses = reduce(losses, 'b ... -> b', 'mean')
259271

260272
if self.min_snr_loss_weight:

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 20 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
from pathlib import Path
44
from random import random
5-
from functools import partial, wraps
5+
from functools import partial
66
from collections import namedtuple
77
from multiprocessing import cpu_count
88

@@ -68,11 +68,6 @@ def convert_image_to_fn(img_type, image):
6868
return image.convert(img_type)
6969
return image
7070

71-
def extract(a, t, x_shape):
72-
b, *_ = t.shape
73-
out = a.gather(-1, t)
74-
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
75-
7671
# normalization functions
7772

7873
def normalize_to_neg_one_to_one(img):
@@ -403,35 +398,13 @@ def forward(self, x, time, x_self_cond = None):
403398
x = self.final_res_block(x, t)
404399
return self.final_conv(x)
405400

406-
# scheduling functions
407-
408-
def enforce_zero_terminal_snr(schedule_fn):
409-
# algorithm 1 in https://arxiv.org/abs/2305.08891
410-
411-
@wraps(schedule_fn)
412-
def inner(*args, **kwargs):
413-
betas = schedule_fn(*args, **kwargs)
414-
alphas = 1. - betas
415-
416-
alphas_cumprod = torch.cumprod(alphas, dim = 0)
417-
alphas_cumprod = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
418-
419-
alphas_cumprod_sqrt = torch.sqrt(alphas_cumprod)
420-
421-
terminal_snr = alphas_cumprod_sqrt[-1].clone()
422-
423-
alphas_cumprod_sqrt -= terminal_snr # enforce zero terminal snr
424-
alphas_cumprod_sqrt *= 1. / (1. - terminal_snr)
425-
426-
alphas_cumprod = alphas_cumprod_sqrt ** 2
427-
alphas = alphas_cumprod[1:] / alphas_cumprod[:-1]
428-
betas = 1. - alphas
429-
430-
return betas
401+
# gaussian diffusion trainer class
431402

432-
return inner
403+
def extract(a, t, x_shape):
404+
b, *_ = t.shape
405+
out = a.gather(-1, t)
406+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
433407

434-
@enforce_zero_terminal_snr
435408
def linear_beta_schedule(timesteps):
436409
"""
437410
linear schedule, proposed in original ddpm paper
@@ -453,7 +426,6 @@ def cosine_beta_schedule(timesteps, s = 0.008):
453426
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
454427
return torch.clip(betas, 0, 1.)
455428

456-
@enforce_zero_terminal_snr
457429
def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
458430
"""
459431
sigmoid schedule
@@ -469,8 +441,6 @@ def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1
469441
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
470442
return torch.clip(betas, 0, 1.)
471443

472-
# gaussian diffusion trainer class
473-
474444
class GaussianDiffusion(nn.Module):
475445
def __init__(
476446
self,
@@ -479,6 +449,7 @@ def __init__(
479449
image_size,
480450
timesteps = 1000,
481451
sampling_timesteps = None,
452+
loss_type = 'l1',
482453
objective = 'pred_noise',
483454
beta_schedule = 'sigmoid',
484455
schedule_fn_kwargs = dict(),
@@ -502,9 +473,7 @@ def __init__(
502473

503474
assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
504475

505-
if callable(beta_schedule):
506-
beta_schedule_fn = beta_schedule
507-
elif beta_schedule == 'linear':
476+
if beta_schedule == 'linear':
508477
beta_schedule_fn = linear_beta_schedule
509478
elif beta_schedule == 'cosine':
510479
beta_schedule_fn = cosine_beta_schedule
@@ -516,11 +485,12 @@ def __init__(
516485
betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
517486

518487
alphas = 1. - betas
519-
alphas_cumprod = torch.cumprod(alphas, dim = 0)
488+
alphas_cumprod = torch.cumprod(alphas, dim=0)
520489
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
521490

522491
timesteps, = betas.shape
523492
self.num_timesteps = int(timesteps)
493+
self.loss_type = loss_type
524494

525495
# sampling related parameters
526496

@@ -541,10 +511,6 @@ def __init__(
541511
# calculations for diffusion q(x_t | x_{t-1}) and others
542512

543513
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
544-
545-
terminal_snr = self.sqrt_alphas_cumprod[-1]
546-
assert terminal_snr == 0, f'non-zero terminal SNR detected ({terminal_snr:.6f}), from https://arxiv.org/abs/2305.08891 paper - you can wrap your schedule function with `enforce_zero_terminal_snr` decorator'
547-
548514
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
549515
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
550516
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
@@ -759,6 +725,15 @@ def q_sample(self, x_start, t, noise=None):
759725
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
760726
)
761727

728+
@property
729+
def loss_fn(self):
730+
if self.loss_type == 'l1':
731+
return F.l1_loss
732+
elif self.loss_type == 'l2':
733+
return F.mse_loss
734+
else:
735+
raise ValueError(f'invalid loss type {self.loss_type}')
736+
762737
def p_losses(self, x_start, t, noise = None):
763738
b, c, h, w = x_start.shape
764739
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -791,7 +766,7 @@ def p_losses(self, x_start, t, noise = None):
791766
else:
792767
raise ValueError(f'unknown objective {self.objective}')
793768

794-
loss = F.mse_loss(model_out, target, reduction = 'none')
769+
loss = self.loss_fn(model_out, target, reduction = 'none')
795770
loss = reduce(loss, 'b ... -> b (...)', 'mean')
796771

797772
loss = loss * extract(self.loss_weight, t, loss.shape)

0 commit comments

Comments
 (0)