@@ -699,7 +699,7 @@ def sample(self, classes, cond_scale = 6., rescaled_phi = 0.7):
699699 return sample_fn (classes , (batch_size , channels , image_size , image_size ), cond_scale , rescaled_phi )
700700
701701 @torch .no_grad ()
702- def interpolate (self , x1 , x2 , t = None , lam = 0.5 ):
702+ def interpolate (self , x1 , x2 , classes , t = None , lam = 0.5 ):
703703 b , * _ , device = * x1 .shape , x1 .device
704704 t = default (t , self .num_timesteps - 1 )
705705
@@ -709,8 +709,9 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
709709 xt1 , xt2 = map (lambda x : self .q_sample (x , t = t_batched ), (x1 , x2 ))
710710
711711 img = (1 - lam ) * xt1 + lam * xt2
712+
712713 for i in tqdm (reversed (range (0 , t )), desc = 'interpolation sample time step' , total = t ):
713- img = self .p_sample (img , torch . full (( b ,), i , device = device , dtype = torch . long ) )
714+ img , _ = self .p_sample (img , i , classes )
714715
715716 return img
716717
@@ -795,3 +796,12 @@ def forward(self, img, *args, **kwargs):
795796 )
796797
797798 sampled_images .shape # (8, 3, 128, 128)
799+
800+ # interpolation
801+
802+ interpolate_out = diffusion .interpolate (
803+ training_images [:1 ],
804+ training_images [:1 ],
805+ image_classes [:1 ]
806+ )
807+
0 commit comments