|
20 | 20 | from einops import rearrange, reduce, repeat |
21 | 21 | from einops.layers.torch import Rearrange |
22 | 22 |
|
| 23 | +from scipy.optimize import linear_sum_assignment |
| 24 | + |
23 | 25 | from PIL import Image |
24 | 26 | from tqdm.auto import tqdm |
25 | 27 | from ema_pytorch import EMA |
@@ -488,7 +490,8 @@ def __init__( |
488 | 490 | auto_normalize = True, |
489 | 491 | offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise |
490 | 492 | min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556 |
491 | | - min_snr_gamma = 5 |
| 493 | + min_snr_gamma = 5, |
| 494 | + immiscible = False |
492 | 495 | ): |
493 | 496 | super().__init__() |
494 | 497 | assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim) |
@@ -564,6 +567,10 @@ def __init__( |
564 | 567 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) |
565 | 568 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) |
566 | 569 |
|
| 570 | + # immiscible diffusion |
| 571 | + |
| 572 | + self.immiscible = immiscible |
| 573 | + |
567 | 574 | # offset noise strength - in blogpost, they claimed 0.1 was ideal |
568 | 575 |
|
569 | 576 | self.offset_noise_strength = offset_noise_strength |
@@ -759,10 +766,20 @@ def interpolate(self, x1, x2, t = None, lam = 0.5): |
759 | 766 |
|
760 | 767 | return img |
761 | 768 |
|
def noise_assignment(self, x_start, noise):
    """Pair every sample in ``x_start`` with a distinct sample in ``noise``.

    Both tensors are flattened to ``(batch, features)``, the pairwise
    Euclidean distances between their rows are computed with
    ``torch.cdist``, and the resulting cost matrix is handed to SciPy's
    ``linear_sum_assignment`` to find the matching with minimal total
    distance.

    Args:
        x_start: tensor whose first dimension is the batch.
        noise: tensor whose first dimension is the batch of candidate
            noise samples; its flattened feature size must match that of
            ``x_start``.

    Returns:
        A 1-D integer tensor on the same device as the distance matrix.
        When both batches have the same size, entry ``i`` is the index of
        the noise sample assigned to ``x_start[i]``, so ``noise[result]``
        reorders the noise to line up with ``x_start``.
    """
    # The matching is a discrete permutation, so no autograd graph is
    # needed. Disabling grad also guarantees the cost matrix can be turned
    # into a NumPy array even when one of the inputs requires grad.
    with torch.no_grad():
        flat_start = x_start.reshape(x_start.shape[0], -1)
        flat_noise = noise.reshape(noise.shape[0], -1)
        dist = torch.cdist(flat_start, flat_noise)

    # SciPy solves the assignment on the host, so copy the costs to CPU.
    _, assign = linear_sum_assignment(dist.cpu().numpy())

    return torch.from_numpy(assign).to(dist.device)
| 774 | + |
762 | 775 | @autocast(enabled = False) |
763 | 776 | def q_sample(self, x_start, t, noise = None): |
764 | 777 | noise = default(noise, lambda: torch.randn_like(x_start)) |
765 | 778 |
|
| 779 | + if self.immiscible: |
| 780 | + assign = self.noise_assignment(x_start, noise) |
| 781 | + noise = noise[assign] |
| 782 | + |
766 | 783 | return ( |
767 | 784 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + |
768 | 785 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise |
|
0 commit comments