Skip to content

Commit 6219bf7

Browse files
committed
adopt offset noise from Nicholas, adopt rescaled cfg from Bytedance
1 parent 1ac4d0d commit 6219bf7

9 files changed

+78
-91
lines changed

README.md

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ model = Unet(
3838
diffusion = GaussianDiffusion(
3939
model,
4040
image_size = 128,
41-
timesteps = 1000, # number of steps
42-
loss_type = 'l1' # L1 or L2
41+
timesteps = 1000 # number of steps
4342
)
4443

4544
training_images = torch.rand(8, 3, 128, 128) # images are normalized from 0 to 1
@@ -65,8 +64,7 @@ diffusion = GaussianDiffusion(
6564
model,
6665
image_size = 128,
6766
timesteps = 1000, # number of steps
68-
sampling_timesteps = 250, # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
69-
loss_type = 'l1' # L1 or L2
67+
sampling_timesteps = 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
7068
)
7169

7270
trainer = Trainer(
@@ -311,3 +309,18 @@ You could consider adding a suitable metric to the training loop yourself after
311309
year = {2023}
312310
}
313311
```
312+
313+
```bibtex
314+
@misc{Guttenberg2023,
315+
author = {Nicholas Guttenberg},
316+
url = {https://www.crosslabs.org/blog/diffusion-with-offset-noise}
317+
}
318+
```
319+
320+
```bibtex
321+
@inproceedings{Lin2023CommonDN,
322+
title = {Common Diffusion Noise Schedules and Sample Steps are Flawed},
323+
author = {Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang},
324+
year = {2023}
325+
}
326+
```

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ def forward_with_cond_scale(
375375
self,
376376
*args,
377377
cond_scale = 1.,
378+
rescaled_phi = 0.,
378379
**kwargs
379380
):
380381
logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
@@ -383,7 +384,15 @@ def forward_with_cond_scale(
383384
return logits
384385

385386
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
386-
return null_logits + (logits - null_logits) * cond_scale
387+
scaled_logits = null_logits + (logits - null_logits) * cond_scale
388+
389+
if rescaled_phi == 0.:
390+
return scaled_logits
391+
392+
std_fn = partial(torch.std, dim = tuple(range(1, scaled_logits.ndim)), keepdim = True)
393+
rescaled_logits = scaled_logits * (std_fn(logits) / std_fn(scaled_logits))
394+
395+
return rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)
387396

388397
def forward(
389398
self,
@@ -483,10 +492,10 @@ def __init__(
483492
image_size,
484493
timesteps = 1000,
485494
sampling_timesteps = None,
486-
loss_type = 'l1',
487495
objective = 'pred_noise',
488496
beta_schedule = 'cosine',
489497
ddim_sampling_eta = 1.,
498+
offset_noise_strength = 0.,
490499
min_snr_loss_weight = False,
491500
min_snr_gamma = 5
492501
):
@@ -516,7 +525,6 @@ def __init__(
516525

517526
timesteps, = betas.shape
518527
self.num_timesteps = int(timesteps)
519-
self.loss_type = loss_type
520528

521529
# sampling related parameters
522530

@@ -556,6 +564,10 @@ def __init__(
556564
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
557565
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
558566

567+
# offset noise strength - 0.1 was claimed ideal
568+
569+
self.offset_noise_strength = offset_noise_strength
570+
559571
# loss weight
560572

561573
snr = alphas_cumprod / (1 - alphas_cumprod)
@@ -606,8 +618,8 @@ def q_posterior(self, x_start, x_t, t):
606618
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
607619
return posterior_mean, posterior_variance, posterior_log_variance_clipped
608620

609-
def model_predictions(self, x, t, classes, cond_scale = 3., clip_x_start = False):
610-
model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale)
621+
def model_predictions(self, x, t, classes, cond_scale = 6., rescaled_phi = 0.7, clip_x_start = False):
622+
model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescaled_phi = rescaled_phi)
611623
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
612624

613625
if self.objective == 'pred_noise':
@@ -628,8 +640,8 @@ def model_predictions(self, x, t, classes, cond_scale = 3., clip_x_start = False
628640

629641
return ModelPrediction(pred_noise, x_start)
630642

631-
def p_mean_variance(self, x, t, classes, cond_scale, clip_denoised = True):
632-
preds = self.model_predictions(x, t, classes, cond_scale)
643+
def p_mean_variance(self, x, t, classes, cond_scale, rescaled_phi, clip_denoised = True):
644+
preds = self.model_predictions(x, t, classes, cond_scale, rescaled_phi)
633645
x_start = preds.pred_x_start
634646

635647
if clip_denoised:
@@ -639,30 +651,30 @@ def p_mean_variance(self, x, t, classes, cond_scale, clip_denoised = True):
639651
return model_mean, posterior_variance, posterior_log_variance, x_start
640652

641653
@torch.no_grad()
642-
def p_sample(self, x, t: int, classes, cond_scale = 3., clip_denoised = True):
654+
def p_sample(self, x, t: int, classes, cond_scale = 6., rescaled_phi = 0.7, clip_denoised = True):
643655
b, *_, device = *x.shape, x.device
644656
batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
645-
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, classes = classes, cond_scale = cond_scale, clip_denoised = clip_denoised)
657+
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, classes = classes, cond_scale = cond_scale, rescaled_phi = rescaled_phi, clip_denoised = clip_denoised)
646658
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
647659
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
648660
return pred_img, x_start
649661

650662
@torch.no_grad()
651-
def p_sample_loop(self, classes, shape, cond_scale = 3.):
663+
def p_sample_loop(self, classes, shape, cond_scale = 6., rescaled_phi = 0.7):
652664
batch, device = shape[0], self.betas.device
653665

654666
img = torch.randn(shape, device=device)
655667

656668
x_start = None
657669

658670
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
659-
img, x_start = self.p_sample(img, t, classes, cond_scale)
671+
img, x_start = self.p_sample(img, t, classes, cond_scale, rescaled_phi)
660672

661673
img = unnormalize_to_zero_to_one(img)
662674
return img
663675

664676
@torch.no_grad()
665-
def ddim_sample(self, classes, shape, cond_scale = 3., clip_denoised = True):
677+
def ddim_sample(self, classes, shape, cond_scale = 6., rescaled_phi = 0.7, clip_denoised = True):
666678
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
667679

668680
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -697,10 +709,10 @@ def ddim_sample(self, classes, shape, cond_scale = 3., clip_denoised = True):
697709
return img
698710

699711
@torch.no_grad()
700-
def sample(self, classes, cond_scale = 3.):
712+
def sample(self, classes, cond_scale = 6., rescaled_phi = 0.7):
701713
batch_size, image_size, channels = classes.shape[0], self.image_size, self.channels
702714
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
703-
return sample_fn(classes, (batch_size, channels, image_size, image_size), cond_scale)
715+
return sample_fn(classes, (batch_size, channels, image_size, image_size), cond_scale, rescaled_phi)
704716

705717
@torch.no_grad()
706718
def interpolate(self, x1, x2, t = None, lam = 0.5):
@@ -721,20 +733,15 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
721733
def q_sample(self, x_start, t, noise=None):
722734
noise = default(noise, lambda: torch.randn_like(x_start))
723735

736+
if self.offset_noise_strength > 0.:
737+
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
738+
noise += self.offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
739+
724740
return (
725741
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
726742
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
727743
)
728744

729-
@property
730-
def loss_fn(self):
731-
if self.loss_type == 'l1':
732-
return F.l1_loss
733-
elif self.loss_type == 'l2':
734-
return F.mse_loss
735-
else:
736-
raise ValueError(f'invalid loss type {self.loss_type}')
737-
738745
def p_losses(self, x_start, t, *, classes, noise = None):
739746
b, c, h, w = x_start.shape
740747
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -757,7 +764,7 @@ def p_losses(self, x_start, t, *, classes, noise = None):
757764
else:
758765
raise ValueError(f'unknown objective {self.objective}')
759766

760-
loss = self.loss_fn(model_out, target, reduction = 'none')
767+
loss = F.mse_loss(model_out, target, reduction = 'none')
761768
loss = reduce(loss, 'b ... -> b (...)', 'mean')
762769

763770
loss = loss * extract(self.loss_weight, t, loss.shape)
@@ -799,7 +806,7 @@ def forward(self, img, *args, **kwargs):
799806

800807
sampled_images = diffusion.sample(
801808
classes = image_classes,
802-
cond_scale = 3. # condition scaling, anything greater than 1 strengthens the classifier free guidance. reportedly 3-8 is good empirically
809+
cond_scale = 6. # condition scaling, anything greater than 1 strengthens the classifier free guidance. reportedly 3-8 is good empirically
803810
)
804811

805812
sampled_images.shape # (8, 3, 128, 128)

denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def __init__(
116116
*,
117117
image_size,
118118
channels = 3,
119-
loss_type = 'l1',
120119
noise_schedule = 'linear',
121120
num_sample_steps = 500,
122121
clip_sample_denoised = True,
@@ -138,8 +137,6 @@ def __init__(
138137

139138
# continuous noise schedule related stuff
140139

141-
self.loss_type = loss_type
142-
143140
if noise_schedule == 'linear':
144141
self.log_snr = beta_linear_log_snr
145142
elif noise_schedule == 'cosine':
@@ -170,15 +167,6 @@ def __init__(
170167
def device(self):
171168
return next(self.model.parameters()).device
172169

173-
@property
174-
def loss_fn(self):
175-
if self.loss_type == 'l1':
176-
return F.l1_loss
177-
elif self.loss_type == 'l2':
178-
return F.mse_loss
179-
else:
180-
raise ValueError(f'invalid loss type {self.loss_type}')
181-
182170
def p_mean_variance(self, x, time, time_next):
183171
# reviewer found an error in the equation in the paper (missing sigma)
184172
# following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt
@@ -266,7 +254,7 @@ def p_losses(self, x_start, times, noise = None):
266254
x, log_snr = self.q_sample(x_start = x_start, times = times, noise = noise)
267255
model_out = self.model(x, log_snr)
268256

269-
losses = self.loss_fn(model_out, noise, reduction = 'none')
257+
losses = F.mse_loss(model_out, noise, reduction = 'none')
270258
losses = reduce(losses, 'b ... -> b', 'mean')
271259

272260
if self.min_snr_loss_weight:

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -449,12 +449,12 @@ def __init__(
449449
image_size,
450450
timesteps = 1000,
451451
sampling_timesteps = None,
452-
loss_type = 'l1',
453-
objective = 'pred_noise',
452+
objective = 'pred_v',
454453
beta_schedule = 'sigmoid',
455454
schedule_fn_kwargs = dict(),
456455
ddim_sampling_eta = 0.,
457456
auto_normalize = True,
457+
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
458458
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
459459
min_snr_gamma = 5
460460
):
@@ -490,7 +490,6 @@ def __init__(
490490

491491
timesteps, = betas.shape
492492
self.num_timesteps = int(timesteps)
493-
self.loss_type = loss_type
494493

495494
# sampling related parameters
496495

@@ -530,6 +529,10 @@ def __init__(
530529
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
531530
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
532531

532+
# offset noise strength - in blogpost, they claimed 0.1 was ideal
533+
534+
self.offset_noise_strength = offset_noise_strength
535+
533536
# derive loss weight
534537
# snr - signal noise ratio
535538

@@ -553,6 +556,10 @@ def __init__(
553556
self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
554557
self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
555558

559+
@property
560+
def device(self):
561+
return self.betas.device
562+
556563
def predict_start_from_noise(self, x_t, t, noise):
557564
return (
558565
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
@@ -623,16 +630,16 @@ def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
623630

624631
@torch.no_grad()
625632
def p_sample(self, x, t: int, x_self_cond = None):
626-
b, *_, device = *x.shape, x.device
627-
batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
633+
b, *_, device = *x.shape, self.device
634+
batched_times = torch.full((b,), t, device = device, dtype = torch.long)
628635
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
629636
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
630637
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
631638
return pred_img, x_start
632639

633640
@torch.no_grad()
634641
def p_sample_loop(self, shape, return_all_timesteps = False):
635-
batch, device = shape[0], self.betas.device
642+
batch, device = shape[0], self.device
636643

637644
img = torch.randn(shape, device = device)
638645
imgs = [img]
@@ -651,7 +658,7 @@ def p_sample_loop(self, shape, return_all_timesteps = False):
651658

652659
@torch.no_grad()
653660
def ddim_sample(self, shape, return_all_timesteps = False):
654-
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
661+
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
655662

656663
times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
657664
times = list(reversed(times.int().tolist()))
@@ -717,23 +724,18 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
717724

718725
return img
719726

720-
def q_sample(self, x_start, t, noise=None):
727+
def q_sample(self, x_start, t, noise = None):
721728
noise = default(noise, lambda: torch.randn_like(x_start))
722729

730+
if self.offset_noise_strength > 0.:
731+
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
732+
noise += self.offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
733+
723734
return (
724735
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
725736
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
726737
)
727738

728-
@property
729-
def loss_fn(self):
730-
if self.loss_type == 'l1':
731-
return F.l1_loss
732-
elif self.loss_type == 'l2':
733-
return F.mse_loss
734-
else:
735-
raise ValueError(f'invalid loss type {self.loss_type}')
736-
737739
def p_losses(self, x_start, t, noise = None):
738740
b, c, h, w = x_start.shape
739741
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -766,7 +768,7 @@ def p_losses(self, x_start, t, noise = None):
766768
else:
767769
raise ValueError(f'unknown objective {self.objective}')
768770

769-
loss = self.loss_fn(model_out, target, reduction = 'none')
771+
loss = F.mse_loss(model_out, target, reduction = 'none')
770772
loss = reduce(loss, 'b ... -> b (...)', 'mean')
771773

772774
loss = loss * extract(self.loss_weight, t, loss.shape)

0 commit comments

Comments
 (0)