Skip to content

Commit 7897ddb

Browse files
committed
add the new sigmoid schedule, purportedly more stable and better for larger images than cosine schedule, from https://arxiv.org/abs/2212.11972 Figure 8
1 parent b07ba38 commit 7897ddb

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,11 @@ sampled_seq.shape # (4, 32, 128)
258258
volume = {abs/2208.03641}
259259
}
260260
```
261+
262+
```bibtex
263+
@inproceedings{Jabri2022ScalableAC,
264+
title = {Scalable Adaptive Computation for Iterative Generation},
265+
author = {A. Jabri and David J. Fleet and Ting Chen},
266+
year = {2022}
267+
}
268+
```

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,9 @@ def extract(a, t, x_shape):
401401
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
402402

403403
def linear_beta_schedule(timesteps):
404+
"""
405+
linear schedule, proposed in original ddpm paper
406+
"""
404407
scale = 1000 / timesteps
405408
beta_start = scale * 0.0001
406409
beta_end = scale * 0.02
@@ -412,8 +415,23 @@ def cosine_beta_schedule(timesteps, s = 0.008):
412415
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
413416
"""
414417
steps = timesteps + 1
415-
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
416-
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
418+
t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
419+
alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
420+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
421+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
422+
return torch.clip(betas, 0, 0.999)
423+
424+
def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """
    Sigmoid noise schedule, proposed in https://arxiv.org/abs/2212.11972 (Figure 8).

    Reported to be more stable and better suited than the cosine schedule for
    images larger than 64x64 when used during training.

    Returns a 1-D float64 tensor of `timesteps` betas, each clipped to [0, 0.999].

    NOTE(review): `clamp_min` is accepted but never referenced in the body —
    presumably intended for clamping alphas_cumprod; confirm intent upstream.
    """
    num_points = timesteps + 1
    # normalized time grid on [0, 1]; float64 keeps the cumulative ratios precise
    t = torch.linspace(0, timesteps, num_points, dtype = torch.float64) / timesteps
    # sigmoid values at the two endpoints, used to rescale the curve onto [0, 1]
    sig_lo = torch.tensor(start / tau).sigmoid()
    sig_hi = torch.tensor(end / tau).sigmoid()
    # decreasing sigmoid curve, affinely mapped so it runs from ~1 down to 0
    curve = (sig_hi - ((t * (end - start) + start) / tau).sigmoid()) / (sig_hi - sig_lo)
    # normalize so the cumulative alpha product starts at exactly 1
    alphas_cumprod = curve / curve[0]
    # per-step betas from cumulative products: beta_t = 1 - abar_t / abar_{t-1}
    betas = 1. - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return torch.clip(betas, 0, 0.999)
@@ -429,6 +447,7 @@ def __init__(
429447
loss_type = 'l1',
430448
objective = 'pred_noise',
431449
beta_schedule = 'cosine',
450+
schedule_fn_kwargs = dict(),
432451
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
433452
p2_loss_weight_k = 1,
434453
ddim_sampling_eta = 0.
@@ -448,12 +467,16 @@ def __init__(
448467
assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
449468

450469
if beta_schedule == 'linear':
451-
betas = linear_beta_schedule(timesteps)
470+
beta_schedule_fn = linear_beta_schedule
452471
elif beta_schedule == 'cosine':
453-
betas = cosine_beta_schedule(timesteps)
472+
beta_schedule_fn = cosine_beta_schedule
473+
elif beta_schedule == 'sigmoid':
474+
beta_schedule_fn = sigmoid_beta_schedule
454475
else:
455476
raise ValueError(f'unknown beta schedule {beta_schedule}')
456477

478+
betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
479+
457480
alphas = 1. - betas
458481
alphas_cumprod = torch.cumprod(alphas, dim=0)
459482
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.6'
1+
__version__ = '0.1.7'

0 commit comments

Comments
 (0)