Skip to content

Commit b21954a

Browse files
authored
Merge pull request #322 from AlejandroSantorum/main
Fix RePaint: use ground-truth in the last in-painting step
2 parents df09945 + 1b1f469 commit b21954a

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

denoising_diffusion_pytorch/repaint.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,9 +677,16 @@ def p_sample(self, x, t: int, x_self_cond = None, gt=None, mask=None):
677677

678678
b, *_, device = *x.shape, self.device
679679
batched_times = torch.full((b,), t, device = device, dtype = torch.long)
680-
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
680+
model_mean, _, model_log_variance, x_start = self.p_mean_variance(
681+
x=x, t=batched_times, x_self_cond=x_self_cond, clip_denoised=True
682+
)
681683
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
682684
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
685+
686+
if t==0 and mask is not None:
687+
# if t == 0, we use the ground-truth image if in-painting
688+
pred_img = (mask * gt) + ((1 - mask) * pred_img)
689+
683690
return pred_img, x_start
684691

685692
@torch.inference_mode()
@@ -707,7 +714,7 @@ def p_sample_loop(
707714
imgs.append(img)
708715

709716
# Resampling loop: line 9 of Algorithm 1 in https://arxiv.org/pdf/2201.09865
710-
if resample is True and (t > 0) and (t % resample_every == 0):
717+
if resample is True and (t > 0) and (t % resample_every == 0 or t == 1):
711718
# Jump back for resample_jump timesteps and resample_iter times
712719
for iter in tqdm(range(resample_iter), desc = 'resample loop', total = resample_iter):
713720
t = resample_jump

0 commit comments

Comments
 (0)