|
8 | 8 | from torch import Tensor |
9 | 9 | from tqdm import tqdm |
10 | 10 |
|
| 11 | +from .utils import default |
| 12 | + |
11 | 13 | """ Distributions """ |
12 | 14 |
|
13 | 15 |
|
@@ -166,6 +168,7 @@ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: |
166 | 168 | alpha, beta = torch.cos(angle), torch.sin(angle) |
167 | 169 | return alpha, beta |
168 | 170 |
|
| 171 | + @torch.no_grad() |
169 | 172 | def forward( # type: ignore |
170 | 173 | self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs |
171 | 174 | ) -> Tensor: |
@@ -242,6 +245,7 @@ def sample_start(self, num_items: int, num_steps: int, **kwargs) -> Tensor: |
242 | 245 | # Sample start |
243 | 246 | return self.sample_loop(current=noise, sigmas=sigmas, **kwargs) |
244 | 247 |
|
| 248 | + @torch.no_grad() |
245 | 249 | def forward( |
246 | 250 | self, |
247 | 251 | num_items: int, |
@@ -289,3 +293,61 @@ def forward( |
289 | 293 | chunks += [torch.randn(shape, device=self.device)] |
290 | 294 |
|
291 | 295 | return torch.cat(chunks[:num_chunks], dim=-1) |
| 296 | + |
| 297 | + |
| 298 | +""" Inpainters """ |
| 299 | + |
| 300 | + |
| 301 | +class Inpainter(nn.Module): |
| 302 | + pass |
| 303 | + |
| 304 | + |
| 305 | +class VInpainter(Inpainter): |
| 306 | + |
| 307 | + diffusion_types = [VDiffusion] |
| 308 | + |
| 309 | + def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()): |
| 310 | + super().__init__() |
| 311 | + self.net = net |
| 312 | + self.schedule = schedule |
| 313 | + |
| 314 | + def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: |
| 315 | + angle = sigmas * pi / 2 |
| 316 | + alpha, beta = torch.cos(angle), torch.sin(angle) |
| 317 | + return alpha, beta |
| 318 | + |
| 319 | + @torch.no_grad() |
| 320 | + def forward( # type: ignore |
| 321 | + self, |
| 322 | + source: Tensor, |
| 323 | + mask: Tensor, |
| 324 | + num_steps: int, |
| 325 | + num_resamples: int, |
| 326 | + show_progress: bool = False, |
| 327 | + x_noisy: Optional[Tensor] = None, |
| 328 | + **kwargs, |
| 329 | + ) -> Tensor: |
| 330 | + x_noisy = default(x_noisy, lambda: torch.randn_like(source)) |
| 331 | + b = x_noisy.shape[0] |
| 332 | + sigmas = self.schedule(num_steps + 1, device=x_noisy.device) |
| 333 | + sigmas = repeat(sigmas, "i -> i b", b=b) |
| 334 | + sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1) |
| 335 | + alphas, betas = self.get_alpha_beta(sigmas_batch) |
| 336 | + progress_bar = tqdm(range(num_steps), disable=not show_progress) |
| 337 | + |
| 338 | + for i in progress_bar: |
| 339 | + for r in range(num_resamples): |
| 340 | + v_pred = self.net(x_noisy, sigmas[i], **kwargs) |
| 341 | + x_pred = alphas[i] * x_noisy - betas[i] * v_pred |
| 342 | + noise_pred = betas[i] * x_noisy + alphas[i] * v_pred |
| 343 | + # Renoise to current noise level if resampling |
| 344 | + j = r == num_resamples - 1 |
| 345 | + x_noisy = alphas[i + j] * x_pred + betas[i + j] * noise_pred |
| 346 | + s_noisy = alphas[i + j] * source + betas[i + j] * torch.randn_like( |
| 347 | + source |
| 348 | + ) |
| 349 | + x_noisy = s_noisy * mask + x_noisy * ~mask |
| 350 | + |
| 351 | + progress_bar.set_description(f"Inpainting (noise={sigmas[i+1,0]:.2f})") |
| 352 | + |
| 353 | + return x_noisy |
0 commit comments