@@ -375,6 +375,7 @@ def forward_with_cond_scale(
375375 self ,
376376 * args ,
377377 cond_scale = 1. ,
378+ rescaled_phi = 0. ,
378379 ** kwargs
379380 ):
380381 logits = self .forward (* args , cond_drop_prob = 0. , ** kwargs )
@@ -383,7 +384,15 @@ def forward_with_cond_scale(
383384 return logits
384385
385386 null_logits = self .forward (* args , cond_drop_prob = 1. , ** kwargs )
386- return null_logits + (logits - null_logits ) * cond_scale
387+ scaled_logits = null_logits + (logits - null_logits ) * cond_scale
388+
389+ if rescaled_phi == 0. :
390+ return scaled_logits
391+
392+ std_fn = partial (torch .std , dim = tuple (range (1 , scaled_logits .ndim )), keepdim = True )
393+ rescaled_logits = scaled_logits * (std_fn (logits ) / std_fn (scaled_logits ))
394+
395+ return rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi )
387396
388397 def forward (
389398 self ,
@@ -483,10 +492,10 @@ def __init__(
483492 image_size ,
484493 timesteps = 1000 ,
485494 sampling_timesteps = None ,
486- loss_type = 'l1' ,
487495 objective = 'pred_noise' ,
488496 beta_schedule = 'cosine' ,
489497 ddim_sampling_eta = 1. ,
498+ offset_noise_strength = 0. ,
490499 min_snr_loss_weight = False ,
491500 min_snr_gamma = 5
492501 ):
@@ -516,7 +525,6 @@ def __init__(
516525
517526 timesteps , = betas .shape
518527 self .num_timesteps = int (timesteps )
519- self .loss_type = loss_type
520528
521529 # sampling related parameters
522530
@@ -556,6 +564,10 @@ def __init__(
556564 register_buffer ('posterior_mean_coef1' , betas * torch .sqrt (alphas_cumprod_prev ) / (1. - alphas_cumprod ))
557565 register_buffer ('posterior_mean_coef2' , (1. - alphas_cumprod_prev ) * torch .sqrt (alphas ) / (1. - alphas_cumprod ))
558566
567+ # offset noise strength - 0.1 was claimed ideal
568+
569+ self .offset_noise_strength = offset_noise_strength
570+
559571 # loss weight
560572
561573 snr = alphas_cumprod / (1 - alphas_cumprod )
@@ -606,8 +618,8 @@ def q_posterior(self, x_start, x_t, t):
606618 posterior_log_variance_clipped = extract (self .posterior_log_variance_clipped , t , x_t .shape )
607619 return posterior_mean , posterior_variance , posterior_log_variance_clipped
608620
609- def model_predictions (self , x , t , classes , cond_scale = 3. , clip_x_start = False ):
610- model_output = self .model .forward_with_cond_scale (x , t , classes , cond_scale = cond_scale )
621+ def model_predictions (self , x , t , classes , cond_scale = 6. , rescaled_phi = 0.7 , clip_x_start = False ):
622+ model_output = self .model .forward_with_cond_scale (x , t , classes , cond_scale = cond_scale , rescaled_phi = rescaled_phi )
611623 maybe_clip = partial (torch .clamp , min = - 1. , max = 1. ) if clip_x_start else identity
612624
613625 if self .objective == 'pred_noise' :
@@ -628,8 +640,8 @@ def model_predictions(self, x, t, classes, cond_scale = 3., clip_x_start = False
628640
629641 return ModelPrediction (pred_noise , x_start )
630642
631- def p_mean_variance (self , x , t , classes , cond_scale , clip_denoised = True ):
632- preds = self .model_predictions (x , t , classes , cond_scale )
643+ def p_mean_variance (self , x , t , classes , cond_scale , rescaled_phi , clip_denoised = True ):
644+ preds = self .model_predictions (x , t , classes , cond_scale , rescaled_phi )
633645 x_start = preds .pred_x_start
634646
635647 if clip_denoised :
@@ -639,30 +651,30 @@ def p_mean_variance(self, x, t, classes, cond_scale, clip_denoised = True):
639651 return model_mean , posterior_variance , posterior_log_variance , x_start
640652
641653 @torch .no_grad ()
642- def p_sample (self , x , t : int , classes , cond_scale = 3. , clip_denoised = True ):
654+ def p_sample (self , x , t : int , classes , cond_scale = 6. , rescaled_phi = 0.7 , clip_denoised = True ):
643655 b , * _ , device = * x .shape , x .device
644656 batched_times = torch .full ((x .shape [0 ],), t , device = x .device , dtype = torch .long )
645- model_mean , _ , model_log_variance , x_start = self .p_mean_variance (x = x , t = batched_times , classes = classes , cond_scale = cond_scale , clip_denoised = clip_denoised )
657+ model_mean , _ , model_log_variance , x_start = self .p_mean_variance (x = x , t = batched_times , classes = classes , cond_scale = cond_scale , rescaled_phi = rescaled_phi , clip_denoised = clip_denoised )
646658 noise = torch .randn_like (x ) if t > 0 else 0. # no noise if t == 0
647659 pred_img = model_mean + (0.5 * model_log_variance ).exp () * noise
648660 return pred_img , x_start
649661
650662 @torch .no_grad ()
651- def p_sample_loop (self , classes , shape , cond_scale = 3. ):
663+ def p_sample_loop (self , classes , shape , cond_scale = 6. , rescaled_phi = 0.7 ):
652664 batch , device = shape [0 ], self .betas .device
653665
654666 img = torch .randn (shape , device = device )
655667
656668 x_start = None
657669
658670 for t in tqdm (reversed (range (0 , self .num_timesteps )), desc = 'sampling loop time step' , total = self .num_timesteps ):
659- img , x_start = self .p_sample (img , t , classes , cond_scale )
671+ img , x_start = self .p_sample (img , t , classes , cond_scale , rescaled_phi )
660672
661673 img = unnormalize_to_zero_to_one (img )
662674 return img
663675
664676 @torch .no_grad ()
665- def ddim_sample (self , classes , shape , cond_scale = 3. , clip_denoised = True ):
677+ def ddim_sample (self , classes , shape , cond_scale = 6. , rescaled_phi = 0.7 , clip_denoised = True ):
666678 batch , device , total_timesteps , sampling_timesteps , eta , objective = shape [0 ], self .betas .device , self .num_timesteps , self .sampling_timesteps , self .ddim_sampling_eta , self .objective
667679
668680 times = torch .linspace (- 1 , total_timesteps - 1 , steps = sampling_timesteps + 1 ) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -697,10 +709,10 @@ def ddim_sample(self, classes, shape, cond_scale = 3., clip_denoised = True):
697709 return img
698710
699711 @torch .no_grad ()
700- def sample (self , classes , cond_scale = 3. ):
712+ def sample (self , classes , cond_scale = 6. , rescaled_phi = 0.7 ):
701713 batch_size , image_size , channels = classes .shape [0 ], self .image_size , self .channels
702714 sample_fn = self .p_sample_loop if not self .is_ddim_sampling else self .ddim_sample
703- return sample_fn (classes , (batch_size , channels , image_size , image_size ), cond_scale )
715+ return sample_fn (classes , (batch_size , channels , image_size , image_size ), cond_scale , rescaled_phi )
704716
705717 @torch .no_grad ()
706718 def interpolate (self , x1 , x2 , t = None , lam = 0.5 ):
@@ -721,20 +733,15 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
721733 def q_sample (self , x_start , t , noise = None ):
722734 noise = default (noise , lambda : torch .randn_like (x_start ))
723735
736+ if self .offset_noise_strength > 0. :
737+ offset_noise = torch .randn (x_start .shape [:2 ], device = self .device )
738+ noise += self .offset_noise_strength * rearrange (offset_noise , 'b c -> b c 1 1' )
739+
724740 return (
725741 extract (self .sqrt_alphas_cumprod , t , x_start .shape ) * x_start +
726742 extract (self .sqrt_one_minus_alphas_cumprod , t , x_start .shape ) * noise
727743 )
728744
729- @property
730- def loss_fn (self ):
731- if self .loss_type == 'l1' :
732- return F .l1_loss
733- elif self .loss_type == 'l2' :
734- return F .mse_loss
735- else :
736- raise ValueError (f'invalid loss type { self .loss_type } ' )
737-
738745 def p_losses (self , x_start , t , * , classes , noise = None ):
739746 b , c , h , w = x_start .shape
740747 noise = default (noise , lambda : torch .randn_like (x_start ))
@@ -757,7 +764,7 @@ def p_losses(self, x_start, t, *, classes, noise = None):
757764 else :
758765 raise ValueError (f'unknown objective { self .objective } ' )
759766
760- loss = self . loss_fn (model_out , target , reduction = 'none' )
767+ loss = F . mse_loss (model_out , target , reduction = 'none' )
761768 loss = reduce (loss , 'b ... -> b (...)' , 'mean' )
762769
763770 loss = loss * extract (self .loss_weight , t , loss .shape )
@@ -799,7 +806,7 @@ def forward(self, img, *args, **kwargs):
799806
800807 sampled_images = diffusion .sample (
801808 classes = image_classes ,
802- cond_scale = 3 . # condition scaling, anything greater than 1 strengthens the classifier free guidance. reportedly 3-8 is good empirically
809+ cond_scale = 6 . # condition scaling, anything greater than 1 strengthens the classifier free guidance. reportedly 3-8 is good empirically
803810 )
804811
805812 sampled_images .shape # (8, 3, 128, 128)
0 commit comments