22import copy
33from pathlib import Path
44from random import random
5- from functools import partial
5+ from functools import partial , wraps
66from collections import namedtuple
77from multiprocessing import cpu_count
88
@@ -375,6 +375,7 @@ def forward_with_cond_scale(
375375 self ,
376376 * args ,
377377 cond_scale = 1. ,
378+ rescale_phi = 0. ,
378379 ** kwargs
379380 ):
380381 logits = self .forward (* args , cond_drop_prob = 0. , ** kwargs )
@@ -383,7 +384,18 @@ def forward_with_cond_scale(
383384 return logits
384385
385386 null_logits = self .forward (* args , cond_drop_prob = 1. , ** kwargs )
386- return null_logits + (logits - null_logits ) * cond_scale
387+ scaled_logits = null_logits + (logits - null_logits ) * cond_scale
388+
389+ if rescale_phi <= 0 :
390+ return scaled_logits
391+
392+ # rescaling proposed in https://arxiv.org/abs/2305.08891 to prevent over-saturation
393+ # works for both pixel and latent space, as opposed to only pixel space with the technique from Imagen
394+ # they found 0.7 to work well empirically with a conditional scale of 6.
395+
396+ std_fn = partial (torch .std , dim = tuple (range (1 , scaled_logits .ndim )), keepdim = True )
397+ rescaled_logits = scaled_logits * (std_fn (logits ) / std_fn (scaled_logits ))
398+ return rescaled_logits * rescale_phi + (1 - rescale_phi ) * scaled_logits
387399
388400 def forward (
389401 self ,
@@ -457,6 +469,33 @@ def extract(a, t, x_shape):
457469 out = a .gather (- 1 , t )
458470 return out .reshape (b , * ((1 ,) * (len (x_shape ) - 1 )))
459471
def enforce_zero_terminal_snr(schedule_fn):
    """Decorator implementing Algorithm 1 of https://arxiv.org/abs/2305.08891.

    Wraps a beta-schedule function so the returned schedule has exactly zero
    terminal SNR (sqrt(alphas_cumprod)[-1] == 0), which the paper shows is
    required for the model to be trained on (and sample from) pure noise.

    The wrapped function's signature and output length are preserved: given
    `timesteps` it still returns a betas tensor of shape (timesteps,).
    """

    @wraps(schedule_fn)
    def inner(*args, **kwargs):
        betas = schedule_fn(*args, **kwargs)
        alphas = 1. - betas

        # NOTE: operate on the FULL cumulative product. The previous version
        # padded `alphas_cumprod[:-1]` with a leading 1, which silently
        # discarded the true terminal alpha-bar (so the wrong element was
        # zeroed) and shrank the schedule from T to T-1 betas.
        alphas_cumprod = torch.cumprod(alphas, dim = 0)
        alphas_cumprod_sqrt = torch.sqrt(alphas_cumprod)

        first = alphas_cumprod_sqrt[0].clone()
        terminal = alphas_cumprod_sqrt[-1].clone()

        # shift so the last sqrt(alpha-bar) is exactly 0 (zero terminal snr),
        # then rescale so the first sqrt(alpha-bar) keeps its original value
        alphas_cumprod_sqrt -= terminal
        alphas_cumprod_sqrt *= first / (first - terminal)

        alphas_cumprod = alphas_cumprod_sqrt ** 2

        # recover per-step alphas from consecutive cumprod ratios;
        # re-prepend alpha_0 = alpha-bar_0 to restore length T (paper, Alg. 1)
        alphas = alphas_cumprod[1:] / alphas_cumprod[:-1]
        alphas = torch.cat((alphas_cumprod[:1], alphas))

        betas = 1. - alphas
        return betas

    return inner
497+
498+ @enforce_zero_terminal_snr
460499def linear_beta_schedule (timesteps ):
461500 scale = 1000 / timesteps
462501 beta_start = scale * 0.0001
@@ -473,7 +512,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
473512 alphas_cumprod = torch .cos (((x / timesteps ) + s ) / (1 + s ) * math .pi * 0.5 ) ** 2
474513 alphas_cumprod = alphas_cumprod / alphas_cumprod [0 ]
475514 betas = 1 - (alphas_cumprod [1 :] / alphas_cumprod [:- 1 ])
476- return torch .clip (betas , 0 , 0.999 )
515+ return torch .clip (betas , 0 , 1. )
477516
478517class GaussianDiffusion (nn .Module ):
479518 def __init__ (
@@ -606,8 +645,8 @@ def q_posterior(self, x_start, x_t, t):
606645 posterior_log_variance_clipped = extract (self .posterior_log_variance_clipped , t , x_t .shape )
607646 return posterior_mean , posterior_variance , posterior_log_variance_clipped
608647
609- def model_predictions (self , x , t , classes , cond_scale = 3. , clip_x_start = False ):
610- model_output = self .model .forward_with_cond_scale (x , t , classes , cond_scale = cond_scale )
648+ def model_predictions (self , x , t , classes , cond_scale = 6. , rescale_phi = 0.7 , clip_x_start = False ):
649+ model_output = self .model .forward_with_cond_scale (x , t , classes , cond_scale = cond_scale , rescale_phi = rescale_phi )
611650 maybe_clip = partial (torch .clamp , min = - 1. , max = 1. ) if clip_x_start else identity
612651
613652 if self .objective == 'pred_noise' :
@@ -639,7 +678,7 @@ def p_mean_variance(self, x, t, classes, cond_scale, clip_denoised = True):
639678 return model_mean , posterior_variance , posterior_log_variance , x_start
640679
641680 @torch .no_grad ()
642- def p_sample (self , x , t : int , classes , cond_scale = 3. , clip_denoised = True ):
681+ def p_sample (self , x , t : int , classes , cond_scale = 6. , rescale_phi = 0.7 , clip_denoised = True ):
643682 b , * _ , device = * x .shape , x .device
644683 batched_times = torch .full ((x .shape [0 ],), t , device = x .device , dtype = torch .long )
645684 model_mean , _ , model_log_variance , x_start = self .p_mean_variance (x = x , t = batched_times , classes = classes , cond_scale = cond_scale , clip_denoised = clip_denoised )
@@ -648,7 +687,7 @@ def p_sample(self, x, t: int, classes, cond_scale = 3., clip_denoised = True):
648687 return pred_img , x_start
649688
650689 @torch .no_grad ()
651- def p_sample_loop (self , classes , shape , cond_scale = 3. ):
690+ def p_sample_loop (self , classes , shape , cond_scale = 6. , rescale_phi = 0.7 ):
652691 batch , device = shape [0 ], self .betas .device
653692
654693 img = torch .randn (shape , device = device )
@@ -662,7 +701,7 @@ def p_sample_loop(self, classes, shape, cond_scale = 3.):
662701 return img
663702
664703 @torch .no_grad ()
665- def ddim_sample (self , classes , shape , cond_scale = 3. , clip_denoised = True ):
704+ def ddim_sample (self , classes , shape , cond_scale = 6. , rescale_phi = 0.7 , clip_denoised = True ):
666705 batch , device , total_timesteps , sampling_timesteps , eta , objective = shape [0 ], self .betas .device , self .num_timesteps , self .sampling_timesteps , self .ddim_sampling_eta , self .objective
667706
668707 times = torch .linspace (- 1 , total_timesteps - 1 , steps = sampling_timesteps + 1 ) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -726,15 +765,6 @@ def q_sample(self, x_start, t, noise=None):
726765 extract (self .sqrt_one_minus_alphas_cumprod , t , x_start .shape ) * noise
727766 )
728767
729- @property
730- def loss_fn (self ):
731- if self .loss_type == 'l1' :
732- return F .l1_loss
733- elif self .loss_type == 'l2' :
734- return F .mse_loss
735- else :
736- raise ValueError (f'invalid loss type { self .loss_type } ' )
737-
738768 def p_losses (self , x_start , t , * , classes , noise = None ):
739769 b , c , h , w = x_start .shape
740770 noise = default (noise , lambda : torch .randn_like (x_start ))
@@ -757,7 +787,7 @@ def p_losses(self, x_start, t, *, classes, noise = None):
757787 else :
758788 raise ValueError (f'unknown objective { self .objective } ' )
759789
760- loss = self . loss_fn (model_out , target , reduction = 'none' )
790+ loss = F . mse_loss (model_out , target , reduction = 'none' )
761791 loss = reduce (loss , 'b ... -> b (...)' , 'mean' )
762792
763793 loss = loss * extract (self .loss_weight , t , loss .shape )
0 commit comments