Skip to content

Commit d83751e

Browse files
committed
quick fix for interpolate in cfg example
1 parent daf2d28 commit d83751e

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ def sample(self, classes, cond_scale = 6., rescaled_phi = 0.7):
699699
return sample_fn(classes, (batch_size, channels, image_size, image_size), cond_scale, rescaled_phi)
700700

701701
@torch.no_grad()
702-
def interpolate(self, x1, x2, t = None, lam = 0.5):
702+
def interpolate(self, x1, x2, classes, t = None, lam = 0.5):
703703
b, *_, device = *x1.shape, x1.device
704704
t = default(t, self.num_timesteps - 1)
705705

@@ -709,8 +709,9 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
709709
xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
710710

711711
img = (1 - lam) * xt1 + lam * xt2
712+
712713
for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
713-
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
714+
img, _ = self.p_sample(img, i, classes)
714715

715716
return img
716717

@@ -795,3 +796,12 @@ def forward(self, img, *args, **kwargs):
795796
)
796797

797798
sampled_images.shape # (8, 3, 128, 128)
799+
800+
# interpolation
801+
802+
interpolate_out = diffusion.interpolate(
803+
training_images[:1],
804+
training_images[:1],
805+
image_classes[:1]
806+
)
807+
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.8.14'
1+
__version__ = '1.8.15'

0 commit comments

Comments
 (0)