@@ -574,8 +574,12 @@ def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rede
574574
575575 return ModelPrediction (pred_noise , x_start )
576576
577- def p_mean_variance (self , x , t , x_self_cond = None , clip_denoised = True ):
578- preds = self .model_predictions (x , t , x_self_cond )
577+ def p_mean_variance (self , x , t , x_self_cond = None , clip_denoised = True , model_forward_kwargs : dict = dict ()):
578+
579+ if exists (x_self_cond ):
580+ model_forward_kwargs = {** model_forward_kwargs , 'self_cond' : x_self_cond }
581+
582+ preds = self .model_predictions (x , t , ** model_forward_kwargs )
579583 x_start = preds .pred_x_start
580584
581585 if clip_denoised :
@@ -585,38 +589,44 @@ def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
585589 return model_mean , posterior_variance , posterior_log_variance , x_start
586590
@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True, model_forward_kwargs: dict = None):
    """Compute the sample for timestep ``t`` from the current sample ``x``.

    Args:
        x: current sample tensor; its leading dimension is used as the batch size.
        t: integer timestep, broadcast to one long-typed entry per batch element.
        x_self_cond: optional self-conditioning input, forwarded unchanged to
            ``p_mean_variance``.
        clip_denoised: forwarded unchanged to ``p_mean_variance``.
        model_forward_kwargs: extra keyword arguments forwarded to
            ``p_mean_variance``. ``None`` (the default) means "no extras".

    Returns:
        ``(pred_img, x_start)``: the computed sample and the ``x_start`` value
        returned by ``p_mean_variance``.
    """
    # A fresh dict per call: a `dict()` default would be one object shared by
    # every invocation, so a mutation downstream would leak into later calls.
    if model_forward_kwargs is None:
        model_forward_kwargs = {}

    batch_size = x.shape[0]
    batched_times = torch.full((batch_size,), t, device = x.device, dtype = torch.long)

    model_mean, _, model_log_variance, x_start = self.p_mean_variance(
        x = x,
        t = batched_times,
        x_self_cond = x_self_cond,
        clip_denoised = clip_denoised,
        model_forward_kwargs = model_forward_kwargs
    )

    noise = torch.randn_like(x) if t > 0 else 0.  # no noise if t == 0

    # exp(0.5 * log_variance) is the standard deviation
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
    return pred_img, x_start
595599
@torch.no_grad()
def p_sample_loop(self, shape, return_noise = False, model_forward_kwargs: dict = None):
    """Run the sampling loop over every timestep, starting from Gaussian noise.

    Args:
        shape: shape of the tensor to sample.
        return_noise: when true, also return the initial noise tensor the
            loop started from.
        model_forward_kwargs: extra keyword arguments forwarded to
            ``p_sample`` on every step. ``None`` (the default) means "no extras".

    Returns:
        The result of ``self.unnormalize`` applied to the final sample, or
        ``(sample, noise)`` when ``return_noise`` is true. The returned noise
        is the raw starting tensor; ``unnormalize`` is not applied to it.
    """
    # A fresh dict per call rather than a single shared `dict()` default.
    if model_forward_kwargs is None:
        model_forward_kwargs = {}

    # Keep a handle on the starting noise; `img` is rebound (not mutated in
    # place) on each step below, so `noise` still refers to the initial draw.
    noise = torch.randn(shape, device = self.betas.device)
    img = noise

    x_start = None

    for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
        # Feed the previous step's x_start back in only when self-conditioning is on.
        self_cond = x_start if self.self_condition else None
        img, x_start = self.p_sample(img, t, self_cond, model_forward_kwargs = model_forward_kwargs)

    img = self.unnormalize(img)

    if return_noise:
        return img, noise

    return img
610619
611620 @torch .no_grad ()
612- def ddim_sample (self , shape , clip_denoised = True , model_forward_kwargs : dict = dict ()):
621+ def ddim_sample (self , shape , clip_denoised = True , model_forward_kwargs : dict = dict (), return_noise = False ):
613622 batch , device , total_timesteps , sampling_timesteps , eta , objective = shape [0 ], self .betas .device , self .num_timesteps , self .sampling_timesteps , self .ddim_sampling_eta , self .objective
614623
615624 times = torch .linspace (- 1 , total_timesteps - 1 , steps = sampling_timesteps + 1 ) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
616625 times = list (reversed (times .int ().tolist ()))
617626 time_pairs = list (zip (times [:- 1 ], times [1 :])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
618627
619- img = torch .randn (shape , device = device )
628+ noise = torch .randn (shape , device = device )
629+ img = noise
620630
621631 x_start = None
622632
@@ -642,15 +652,19 @@ def ddim_sample(self, shape, clip_denoised = True, model_forward_kwargs: dict =
642652 sigma * noise
643653
644654 img = self .unnormalize (img )
645- return img
655+
656+ if not return_noise :
657+ return img
658+
659+ return img , noise
646660
@torch.no_grad()
def sample(self, batch_size = 16, return_noise = False, model_forward_kwargs: dict = None):
    """Sample a batch using whichever sampler this instance is configured for.

    Args:
        batch_size: number of samples to draw.
        return_noise: forwarded to the sampler; when true the sampler also
            returns the initial noise.
        model_forward_kwargs: extra keyword arguments forwarded to the
            sampler. ``None`` (the default) means "no extras".

    Returns:
        Whatever the selected sampler (``ddim_sample`` or ``p_sample_loop``)
        returns for the requested shape.
    """
    # A fresh dict per call rather than a single shared `dict()` default;
    # the samplers always receive a real dict.
    if model_forward_kwargs is None:
        model_forward_kwargs = {}

    seq_length, channels = self.seq_length, self.channels
    sample_fn = self.ddim_sample if self.is_ddim_sampling else self.p_sample_loop

    # The channel axis comes before or after the sequence axis depending on layout.
    if self.channel_first:
        shape = (batch_size, channels, seq_length)
    else:
        shape = (batch_size, seq_length, channels)

    return sample_fn(shape, return_noise = return_noise, model_forward_kwargs = model_forward_kwargs)
654668
655669 @torch .no_grad ()
656670 def interpolate (self , x1 , x2 , t = None , lam = 0.5 ):
0 commit comments