@@ -677,9 +677,16 @@ def p_sample(self, x, t: int, x_self_cond = None, gt=None, mask=None):
677677
678678 b , * _ , device = * x .shape , self .device
679679 batched_times = torch .full ((b ,), t , device = device , dtype = torch .long )
680- model_mean , _ , model_log_variance , x_start = self .p_mean_variance (x = x , t = batched_times , x_self_cond = x_self_cond , clip_denoised = True )
680+ model_mean , _ , model_log_variance , x_start = self .p_mean_variance (
681+ x = x , t = batched_times , x_self_cond = x_self_cond , clip_denoised = True
682+ )
681683 noise = torch .randn_like (x ) if t > 0 else 0. # no noise if t == 0
682684 pred_img = model_mean + (0.5 * model_log_variance ).exp () * noise
685+
686+ if t == 0 and mask is not None :
687+ # if t == 0, we use the ground-truth image if in-painting
688+ pred_img = (mask * gt ) + ((1 - mask ) * pred_img )
689+
683690 return pred_img , x_start
684691
685692 @torch .inference_mode ()
@@ -707,7 +714,7 @@ def p_sample_loop(
707714 imgs .append (img )
708715
709716 # Resampling loop: line 9 of Algorithm 1 in https://arxiv.org/pdf/2201.09865
710- if resample is True and (t > 0 ) and (t % resample_every == 0 ):
717+ if resample is True and (t > 0 ) and (t % resample_every == 0 or t == 1 ):
711718 # Jump back for resample_jump timesteps and resample_iter times
712719 for iter in tqdm (range (resample_iter ), desc = 'resample loop' , total = resample_iter ):
713720 t = resample_jump
0 commit comments