@@ -423,12 +423,18 @@ def __init__(
         objective = 'pred_noise',
         beta_schedule = 'cosine',
         ddim_sampling_eta = 0.,
-        auto_normalize = True
+        auto_normalize = True,
+        channels = None,
+        self_condition = None,
+        channel_first = True
     ):
         super().__init__()
         self.model = model
-        self.channels = self.model.channels
-        self.self_condition = self.model.self_condition
+        self.channels = default(channels, lambda: self.model.channels)
+        self.self_condition = default(self_condition, lambda: self.model.self_condition)
+
+        self.channel_first = channel_first
+        self.seq_index = -2 if not channel_first else -1

         self.seq_length = seq_length

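The three new constructor arguments let the wrapper override, rather than read, the model's attributes, and `seq_index` records which dimension holds the sequence. A standalone sketch of those semantics, using a copy of the repository's `default` fallback helper; `DummyModel` is a hypothetical stand-in for illustration only:

# standalone mirror of the new __init__ logic; DummyModel and this copy
# of the repo's `default` helper are illustrative stand-ins
def default(val, d):
    if val is not None:
        return val
    return d() if callable(d) else d

class DummyModel:
    channels = 3
    self_condition = False

model = DummyModel()

channels = default(None, lambda: model.channels)               # None -> falls back to model.channels == 3
self_condition = default(False, lambda: model.self_condition)  # an explicit value wins, even if falsy

channel_first = False                                          # sequences arrive as (batch, seq, channels)
seq_index = -2 if not channel_first else -1                    # -> -2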
@@ -638,7 +644,9 @@ def ddim_sample(self, shape, clip_denoised = True):
     def sample(self, batch_size = 16):
         seq_length, channels = self.seq_length, self.channels
         sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
-        return sample_fn((batch_size, channels, seq_length))
+
+        shape = (batch_size, channels, seq_length) if self.channel_first else (batch_size, seq_length, channels)
+        return sample_fn(shape)

     @torch.no_grad()
     def interpolate(self, x1, x2, t = None, lam = 0.5):
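With the flag in place, `sample` builds the starting noise shape to match the configured layout. A quick standalone check of the same conditional (the function name here is hypothetical):

# mirrors the shape conditional added to sample()
def sample_shape(batch_size, channels, seq_length, channel_first):
    return (batch_size, channels, seq_length) if channel_first else (batch_size, seq_length, channels)

assert sample_shape(16, 3, 128, channel_first = True) == (16, 3, 128)
assert sample_shape(16, 3, 128, channel_first = False) == (16, 128, 3)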
@@ -669,8 +677,10 @@ def q_sample(self, x_start, t, noise = None):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

-    def p_losses(self, x_start, t, noise = None):
-        b, c, n = x_start.shape
+    def p_losses(self, x_start, t, noise = None, model_forward_kwargs: dict = dict()):
+        b = x_start.shape[0]
+        n = x_start.shape[self.seq_index]
+
         noise = default(noise, lambda: torch.randn_like(x_start))

         # noise sample
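`p_losses` no longer destructures a fixed `(b, c, n)` shape; it reads the batch size from dim 0 and the sequence length through `seq_index`, so both layouts work. Illustrated on stand-in tensors:

import torch

x_channel_first = torch.randn(16, 3, 128)   # (batch, channels, seq), seq_index = -1
x_channel_last  = torch.randn(16, 128, 3)   # (batch, seq, channels), seq_index = -2

assert x_channel_first.shape[-1] == 128
assert x_channel_last.shape[-2] == 128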
@@ -687,9 +697,13 @@ def p_losses(self, x_start, t, noise = None):
                 x_self_cond = self.model_predictions(x, t).pred_x_start
                 x_self_cond.detach_()

+        # model kwargs
+
+        model_forward_kwargs = {**model_forward_kwargs, 'self_cond': x_self_cond}
+
         # predict and take gradient step

-        model_out = self.model(x, t, x_self_cond)
+        model_out = self.model(x, t, **model_forward_kwargs)

         if self.objective == 'pred_noise':
             target = noise
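The new `model_forward_kwargs` parameter lets callers thread extra keyword arguments through to the denoising model; the self-conditioning tensor is merged in non-destructively under the `'self_cond'` key, so the wrapped model is now expected to accept that keyword. A standalone sketch of the merge (the `'cond'` entry is a hypothetical example, not from the diff):

def merge_forward_kwargs(model_forward_kwargs, x_self_cond):
    # dict splat keeps caller-supplied kwargs and adds (or overwrites) self_cond
    return {**model_forward_kwargs, 'self_cond': x_self_cond}

kwargs = merge_forward_kwargs({'cond': 'extra conditioning'}, None)
assert kwargs == {'cond': 'extra conditioning', 'self_cond': None}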
@@ -708,7 +722,8 @@ def p_losses(self, x_start, t, noise = None):
         return loss.mean()

     def forward(self, img, *args, **kwargs):
-        b, c, n, device, seq_length, = *img.shape, img.device, self.seq_length
+        b, n, device, seq_length = img.shape[0], img.shape[self.seq_index], img.device, self.seq_length
+
         assert n == seq_length, f'seq length must be {seq_length}'
         t = torch.randint(0, self.num_timesteps, (b,), device = device).long()

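Taken together, the changes mean a channel-last model that lacks `channels` / `self_condition` attributes can be wrapped directly. A hypothetical end-to-end sketch, assuming a version of `GaussianDiffusion1D` that includes this commit and the package's usual export; `TinyDenoiser` is an illustrative stand-in, not library code:

import torch
from torch import nn
from denoising_diffusion_pytorch import GaussianDiffusion1D  # assumed export

# hypothetical channel-last denoiser: maps (batch, seq, channels) -> (batch, seq, channels)
class TinyDenoiser(nn.Module):
    def __init__(self, channels = 3, dim = 64):
        super().__init__()
        self.proj_in = nn.Linear(channels, dim)
        self.to_time = nn.Linear(1, dim)
        self.proj_out = nn.Linear(dim, channels)

    def forward(self, x, t, self_cond = None):   # accepts the merged 'self_cond' kwarg
        h = self.proj_in(x) + self.to_time(t.float()[:, None, None] / 1000.)
        return self.proj_out(torch.relu(h))

diffusion = GaussianDiffusion1D(
    TinyDenoiser(),
    seq_length = 128,
    channels = 3,            # overrides: no longer read off the model
    self_condition = False,
    channel_first = False    # training data is (batch, seq, channels)
)

loss = diffusion(torch.randn(8, 128, 3))
loss.backward()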