Skip to content

Commit 5a0e07f

Browse files
committed
add immiscible diffusion
1 parent ec0a1c7 commit 5a0e07f

File tree

4 files changed

+31
-2
lines changed

4 files changed

+31
-2
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,14 @@ You could consider adding a suitable metric to the training loop yourself after
355355
url = {https://api.semanticscholar.org/CorpusID:265659032}
356356
}
357357
```
358+
359+
```bibtex
360+
@article{Li2024ImmiscibleDA,
361+
title = {Immiscible Diffusion: Accelerating Diffusion Training with Noise Assignment},
362+
author = {Yiheng Li and Heyang Jiang and Akio Kodaira and Masayoshi Tomizuka and Kurt Keutzer and Chenfeng Xu},
363+
journal = {ArXiv},
364+
year = {2024},
365+
volume = {abs/2406.12303},
366+
url = {https://api.semanticscholar.org/CorpusID:270562607}
367+
}
368+
```

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from einops import rearrange, reduce, repeat
2121
from einops.layers.torch import Rearrange
2222

23+
from scipy.optimize import linear_sum_assignment
24+
2325
from PIL import Image
2426
from tqdm.auto import tqdm
2527
from ema_pytorch import EMA
@@ -488,7 +490,8 @@ def __init__(
488490
auto_normalize = True,
489491
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
490492
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
491-
min_snr_gamma = 5
493+
min_snr_gamma = 5,
494+
immiscible = False
492495
):
493496
super().__init__()
494497
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
@@ -564,6 +567,10 @@ def __init__(
564567
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
565568
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
566569

570+
# immiscible diffusion
571+
572+
self.immiscible = immiscible
573+
567574
# offset noise strength - in blogpost, they claimed 0.1 was ideal
568575

569576
self.offset_noise_strength = offset_noise_strength
@@ -759,10 +766,20 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
759766

760767
return img
761768

769+
def noise_assignment(self, x_start, noise):
770+
x_start, noise = tuple(rearrange(t, 'b ... -> b (...)') for t in (x_start, noise))
771+
dist = torch.cdist(x_start, noise)
772+
_, assign = linear_sum_assignment(dist.cpu())
773+
return torch.from_numpy(assign).to(dist.device)
774+
762775
@autocast(enabled = False)
763776
def q_sample(self, x_start, t, noise = None):
764777
noise = default(noise, lambda: torch.randn_like(x_start))
765778

779+
if self.immiscible:
780+
assign = self.noise_assignment(x_start, noise)
781+
noise = noise[assign]
782+
766783
return (
767784
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
768785
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.0.12'
1+
__version__ = '2.0.15'

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
'numpy',
2424
'pillow',
2525
'pytorch-fid',
26+
'scipy',
2627
'torch',
2728
'torchvision',
2829
'tqdm'

0 commit comments

Comments
 (0)