diff --git a/modules/core/reflow.py b/modules/core/reflow.py index f09eb2392..c61e656dc 100644 --- a/modules/core/reflow.py +++ b/modules/core/reflow.py @@ -100,21 +100,51 @@ def sample_rk5(self, x, t, dt, cond): x += (7 * k_1 + 32 * k_3 + 12 * k_4 + 32 * k_5 + 7 * k_6) * dt / 90 t += dt return x, t - + def inpaint_fn(self,x,mask,inpaint_input,noise,t): + x_m=x*mask + x=(1-mask)*(t*inpaint_input+(1-t)*noise)+x_m + return x + pass @torch.no_grad() - def inference(self, cond, b=1, x_end=None, device=None): + def inference(self, cond, b=1, x_end=None, device=None,inpaint_mask=None,inpaint_input=None): noise = torch.randn(b, self.num_feats, self.out_dims, cond.shape[2], device=device) t_start = hparams.get('T_start_infer', self.t_start) - if self.use_shallow_diffusion and t_start > 0: - assert x_end is not None, 'Missing shallow diffusion source.' + if inpaint_mask is not None: + assert inpaint_input is not None + inpaint_mask=inpaint_mask.float() + if inpaint_mask is None: + if self.use_shallow_diffusion and t_start > 0: + assert x_end is not None, 'Missing shallow diffusion source.' + if t_start >= 1.: + t_start = 1. + x = x_end + else: + x = t_start * x_end + (1 - t_start) * noise + else: + t_start = 0. + x = noise + else: if t_start >= 1.: t_start = 1. - x = x_end + x = inpaint_input else: - x = t_start * x_end + (1 - t_start) * noise - else: - t_start = 0. - x = noise + if x_end is not None: + if t_start > 0: + x_m=(t_start * x_end + (1 - t_start) * noise)*inpaint_mask + x=(1-inpaint_mask)*(t_start * inpaint_input + (1 - t_start) * noise)+x_m + else: + x= noise + t_start = 0. + else: + if t_start > 0: + x_m=(t_start * inpaint_input + (1 - t_start) * noise)*inpaint_mask + x=(1-inpaint_mask)*(t_start * inpaint_input + (1 - t_start) * noise)+x_m + else: + x= noise + t_start = 0. + + + algorithm = hparams['sampling_algorithm'] infer_step = hparams['sampling_steps'] @@ -133,6 +163,8 @@ def inference(self, cond, b=1, x_end=None, device=None): for i in tqdm(range(infer_step), desc='sample time step', total=infer_step, disable=not hparams['infer'], leave=False): x, _ = algorithm_fn(x, t_start + i * dts, dt, cond) + if inpaint_mask is not None: + x=self.inpaint_fn(x,inpaint_mask,inpaint_input,noise,t_start+i*dts) x = x.float() x = x.transpose(2, 3).squeeze(1) # [B, F, M, T] => [B, T, M] or [B, F, T, M] return x