@@ -545,8 +545,12 @@ def q_posterior(self, x_start, x_t, t):
545545 posterior_log_variance_clipped = extract (self .posterior_log_variance_clipped , t , x_t .shape )
546546 return posterior_mean , posterior_variance , posterior_log_variance_clipped
547547
548- def model_predictions (self , x , t , x_self_cond = None , clip_x_start = False , rederive_pred_noise = False ):
549- model_output = self .model (x , t , x_self_cond )
548+ def model_predictions (self , x , t , x_self_cond = None , clip_x_start = False , rederive_pred_noise = False , model_forward_kwargs : dict = dict ()):
549+
550+ if exists (x_self_cond ):
551+ model_forward_kwargs = {** model_forward_kwargs , 'self_cond' : x_self_cond }
552+
553+ model_output = self .model (x , t , ** model_forward_kwargs )
550554 maybe_clip = partial (torch .clamp , min = - 1. , max = 1. ) if clip_x_start else identity
551555
552556 if self .objective == 'pred_noise' :
@@ -605,7 +609,7 @@ def p_sample_loop(self, shape):
605609 return img
606610
607611 @torch .no_grad ()
608- def ddim_sample (self , shape , clip_denoised = True ):
612+ def ddim_sample (self , shape , clip_denoised = True , model_forward_kwargs : dict = dict () ):
609613 batch , device , total_timesteps , sampling_timesteps , eta , objective = shape [0 ], self .betas .device , self .num_timesteps , self .sampling_timesteps , self .ddim_sampling_eta , self .objective
610614
611615 times = torch .linspace (- 1 , total_timesteps - 1 , steps = sampling_timesteps + 1 ) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -619,7 +623,7 @@ def ddim_sample(self, shape, clip_denoised = True):
619623 for time , time_next in tqdm (time_pairs , desc = 'sampling loop time step' ):
620624 time_cond = torch .full ((batch ,), time , device = device , dtype = torch .long )
621625 self_cond = x_start if self .self_condition else None
622- pred_noise , x_start , * _ = self .model_predictions (img , time_cond , self_cond , clip_x_start = clip_denoised )
626+ pred_noise , x_start , * _ = self .model_predictions (img , time_cond , self_cond , clip_x_start = clip_denoised , model_forward_kwargs = model_forward_kwargs )
623627
624628 if time_next < 0 :
625629 img = x_start
@@ -641,12 +645,12 @@ def ddim_sample(self, shape, clip_denoised = True):
641645 return img
642646
643647 @torch .no_grad ()
def sample(self, batch_size = 16, model_forward_kwargs: dict = dict()):
    """Draw `batch_size` samples with the configured sampling loop.

    Args:
        batch_size: size of the leading (batch) dimension of the sample shape.
        model_forward_kwargs: extra keyword arguments handed on to the
            sampling loop. This method only reads the dict, never mutates
            it, so the shared default instance is harmless here.

    Returns:
        Whatever the selected sampling loop (`ddim_sample` or
        `p_sample_loop`) returns.

    Raises:
        TypeError: if extra kwargs are supplied while the selected sampling
            loop does not accept `model_forward_kwargs`.
    """
    seq_length, channels = self.seq_length, self.channels
    sample_fn = self.ddim_sample if self.is_ddim_sampling else self.p_sample_loop

    shape = (batch_size, channels, seq_length) if self.channel_first else (batch_size, seq_length, channels)

    # `p_sample_loop` is declared as `(self, shape)` and takes no
    # `model_forward_kwargs`, so forwarding the keyword unconditionally made
    # every non-DDIM call fail with a TypeError. Forward it only when the
    # caller actually provided something; an unsupported request still fails
    # loudly instead of being silently dropped.
    extra = {'model_forward_kwargs': model_forward_kwargs} if model_forward_kwargs else {}
    return sample_fn(shape, **extra)
650654
651655 @torch .no_grad ()
652656 def interpolate (self , x1 , x2 , t = None , lam = 0.5 ):
0 commit comments