 import copy
 from pathlib import Path
 from random import random
-from functools import partial, wraps
+from functools import partial
 from collections import namedtuple
 from multiprocessing import cpu_count

@@ -375,7 +375,6 @@ def forward_with_cond_scale(
375375 self ,
376376 * args ,
377377 cond_scale = 1. ,
378- rescale_phi = 0. ,
379378 ** kwargs
380379 ):
381380 logits = self .forward (* args , cond_drop_prob = 0. , ** kwargs )
@@ -384,18 +383,7 @@ def forward_with_cond_scale(
384383 return logits
385384
386385 null_logits = self .forward (* args , cond_drop_prob = 1. , ** kwargs )
387- scaled_logits = null_logits + (logits - null_logits ) * cond_scale
388-
389- if rescale_phi <= 0 :
390- return scaled_logits
391-
392- # rescaling proposed in https://arxiv.org/abs/2305.08891 to prevent over-saturation
393- # works for both pixel and latent space, as opposed to only pixel space with the technique from Imagen
394- # they found 0.7 to work well empirically with a conditional scale of 6.
395-
396- std_fn = partial (torch .std , dim = tuple (range (1 , scaled_logits .ndim )), keepdim = True )
397- rescaled_logits = scaled_logits * (std_fn (logits ) / std_fn (scaled_logits ))
398- return rescaled_logits * rescale_phi + (1 - rescale_phi ) * scaled_logits
386+ return null_logits + (logits - null_logits ) * cond_scale
399387
400388 def forward (
401389 self ,
@@ -469,33 +457,6 @@ def extract(a, t, x_shape):
469457 out = a .gather (- 1 , t )
470458 return out .reshape (b , * ((1 ,) * (len (x_shape ) - 1 )))
471459
472- def enforce_zero_terminal_snr (schedule_fn ):
473- # algorithm 1 in https://arxiv.org/abs/2305.08891
474-
475- @wraps (schedule_fn )
476- def inner (* args , ** kwargs ):
477- betas = schedule_fn (* args , ** kwargs )
478- alphas = 1. - betas
479-
480- alphas_cumprod = torch .cumprod (alphas , dim = 0 )
481- alphas_cumprod = F .pad (alphas_cumprod [:- 1 ], (1 , 0 ), value = 1. )
482-
483- alphas_cumprod_sqrt = torch .sqrt (alphas_cumprod )
484-
485- terminal_snr = alphas_cumprod_sqrt [- 1 ].clone ()
486-
487- alphas_cumprod_sqrt -= terminal_snr # enforce zero terminal snr
488- alphas_cumprod_sqrt *= 1. / (1. - terminal_snr )
489-
490- alphas_cumprod = alphas_cumprod_sqrt ** 2
491- alphas = alphas_cumprod [1 :] / alphas_cumprod [:- 1 ]
492- betas = 1. - alphas
493-
494- return betas
495-
496- return inner
497-
498- @enforce_zero_terminal_snr
499460def linear_beta_schedule (timesteps ):
500461 scale = 1000 / timesteps
501462 beta_start = scale * 0.0001
@@ -512,7 +473,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
512473 alphas_cumprod = torch .cos (((x / timesteps ) + s ) / (1 + s ) * math .pi * 0.5 ) ** 2
513474 alphas_cumprod = alphas_cumprod / alphas_cumprod [0 ]
514475 betas = 1 - (alphas_cumprod [1 :] / alphas_cumprod [:- 1 ])
515- return torch .clip (betas , 0 , 1. )
476+ return torch .clip (betas , 0 , 0.999 )
516477
517478class GaussianDiffusion (nn .Module ):
518479 def __init__ (
@@ -645,8 +606,8 @@ def q_posterior(self, x_start, x_t, t):
645606 posterior_log_variance_clipped = extract (self .posterior_log_variance_clipped , t , x_t .shape )
646607 return posterior_mean , posterior_variance , posterior_log_variance_clipped
647608
648- def model_predictions (self , x , t , classes , cond_scale = 6. , rescale_phi = 0.7 , clip_x_start = False ):
649- model_output = self .model .forward_with_cond_scale (x , t , classes , cond_scale = cond_scale , rescale_phi = rescale_phi )
609+ def model_predictions (self , x , t , classes , cond_scale = 3. , clip_x_start = False ):
610+ model_output = self .model .forward_with_cond_scale (x , t , classes , cond_scale = cond_scale )
650611 maybe_clip = partial (torch .clamp , min = - 1. , max = 1. ) if clip_x_start else identity
651612
652613 if self .objective == 'pred_noise' :
@@ -678,7 +639,7 @@ def p_mean_variance(self, x, t, classes, cond_scale, clip_denoised = True):
678639 return model_mean , posterior_variance , posterior_log_variance , x_start
679640
680641 @torch .no_grad ()
681- def p_sample (self , x , t : int , classes , cond_scale = 6. , rescale_phi = 0.7 , clip_denoised = True ):
642+ def p_sample (self , x , t : int , classes , cond_scale = 3. , clip_denoised = True ):
682643 b , * _ , device = * x .shape , x .device
683644 batched_times = torch .full ((x .shape [0 ],), t , device = x .device , dtype = torch .long )
684645 model_mean , _ , model_log_variance , x_start = self .p_mean_variance (x = x , t = batched_times , classes = classes , cond_scale = cond_scale , clip_denoised = clip_denoised )
@@ -687,7 +648,7 @@ def p_sample(self, x, t: int, classes, cond_scale = 6., rescale_phi = 0.7, clip_
687648 return pred_img , x_start
688649
689650 @torch .no_grad ()
690- def p_sample_loop (self , classes , shape , cond_scale = 6. , rescale_phi = 0.7 ):
651+ def p_sample_loop (self , classes , shape , cond_scale = 3. ):
691652 batch , device = shape [0 ], self .betas .device
692653
693654 img = torch .randn (shape , device = device )
@@ -701,7 +662,7 @@ def p_sample_loop(self, classes, shape, cond_scale = 6., rescale_phi = 0.7):
701662 return img
702663
703664 @torch .no_grad ()
704- def ddim_sample (self , classes , shape , cond_scale = 6. , rescale_phi = 0.7 , clip_denoised = True ):
665+ def ddim_sample (self , classes , shape , cond_scale = 3. , clip_denoised = True ):
705666 batch , device , total_timesteps , sampling_timesteps , eta , objective = shape [0 ], self .betas .device , self .num_timesteps , self .sampling_timesteps , self .ddim_sampling_eta , self .objective
706667
707668 times = torch .linspace (- 1 , total_timesteps - 1 , steps = sampling_timesteps + 1 ) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -765,6 +726,15 @@ def q_sample(self, x_start, t, noise=None):
765726 extract (self .sqrt_one_minus_alphas_cumprod , t , x_start .shape ) * noise
766727 )
767728
@property
def loss_fn(self):
    """Resolve the configured loss type to its torch.nn.functional callable.

    Returns:
        ``F.l1_loss`` when ``self.loss_type == 'l1'``, ``F.mse_loss`` when
        ``self.loss_type == 'l2'``.

    Raises:
        ValueError: for any other ``loss_type`` value.
    """
    # Dispatch table instead of an if/elif chain; same mapping, same error.
    losses = {
        'l1': F.l1_loss,
        'l2': F.mse_loss,
    }
    if self.loss_type not in losses:
        raise ValueError(f'invalid loss type {self.loss_type}')
    return losses[self.loss_type]
768738 def p_losses (self , x_start , t , * , classes , noise = None ):
769739 b , c , h , w = x_start .shape
770740 noise = default (noise , lambda : torch .randn_like (x_start ))
@@ -787,7 +757,7 @@ def p_losses(self, x_start, t, *, classes, noise = None):
787757 else :
788758 raise ValueError (f'unknown objective { self .objective } ' )
789759
790- loss = F . mse_loss (model_out , target , reduction = 'none' )
760+ loss = self . loss_fn (model_out , target , reduction = 'none' )
791761 loss = reduce (loss , 'b ... -> b (...)' , 'mean' )
792762
793763 loss = loss * extract (self .loss_weight , t , loss .shape )
0 commit comments