Commit bb7a9be

Merge pull request #287 from lucidrains/karras-magnitude-preserving-unet
complete karras unet
2 parents: 32310c3 + 42a9e79 · commit bb7a9be

File tree: 5 files changed, +669 −3 lines changed


README.md

Lines changed: 11 additions & 0 deletions
````diff
@@ -343,3 +343,14 @@ You could consider adding a suitable metric to the training loop yourself after
     url = {https://api.semanticscholar.org/CorpusID:259224568}
 }
 ```
+
+```bibtex
+@article{Karras2023AnalyzingAI,
+    title = {Analyzing and Improving the Training Dynamics of Diffusion Models},
+    author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
+    journal = {ArXiv},
+    year = {2023},
+    volume = {abs/2312.02696},
+    url = {https://api.semanticscholar.org/CorpusID:265659032}
+}
+```
````

denoising_diffusion_pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -7,3 +7,5 @@
 from denoising_diffusion_pytorch.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion
 
 from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D, Dataset1D
+
+from denoising_diffusion_pytorch.karras_unet import KarrasUnet
```
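
With the re-export in place, `KarrasUnet` can be imported directly from the package root. Below is a minimal usage sketch; only the import path is confirmed by this diff — the constructor arguments and forward call are illustrative assumptions, not the actual signature from `karras_unet.py`.

```python
import torch
from denoising_diffusion_pytorch import KarrasUnet  # same class as denoising_diffusion_pytorch.karras_unet.KarrasUnet

# hypothetical configuration: argument names and values are assumptions
# for illustration and may differ from the real KarrasUnet signature
unet = KarrasUnet(
    dim = 64,         # assumed base channel width
    image_size = 64,  # assumed square input resolution
    channels = 3      # assumed number of image channels
)

images = torch.randn(1, 3, 64, 64)
times = torch.rand(1)            # assumed noise-level / time conditioning input
denoised = unet(images, times)   # assumed forward signature
```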

denoising_diffusion_pytorch/attend.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -17,6 +17,9 @@
 def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 def once(fn):
     called = False
     @wraps(fn)
```
```diff
@@ -36,10 +39,12 @@ class Attend(nn.Module):
     def __init__(
         self,
         dropout = 0.,
-        flash = False
+        flash = False,
+        scale = None
     ):
         super().__init__()
         self.dropout = dropout
+        self.scale = scale
         self.attn_dropout = nn.Dropout(dropout)
 
         self.flash = flash
```
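
The new keyword lets callers pin the softmax scale when constructing `Attend`, rather than always deriving it from the head dimension; `scale = None` keeps the previous behavior. A short construction sketch (argument values are arbitrary examples):

```python
from denoising_diffusion_pytorch.attend import Attend

# scale = None: falls back to the usual 1 / sqrt(dim_head) softmax scaling
attn_default = Attend(dropout = 0., flash = False)

# fixed softmax scale, decoupled from the head dimension
attn_pinned = Attend(dropout = 0., flash = True, scale = 10)
```

A fixed scale is typically useful when queries and keys are normalized before attention (cosine-similarity attention), which is presumably why the magnitude-preserving U-Net needs control over it here.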
```diff
@@ -65,6 +70,10 @@ def __init__(
     def flash_attn(self, q, k, v):
         _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
 
+        if exists(self.scale):
+            default_scale = q.shape[-1] ** -0.5
+            q = q * (self.scale / default_scale)
+
         q, k, v = map(lambda t: t.contiguous(), (q, k, v))
 
         # Check if there is a compatible device for flash attention
```
```diff
@@ -95,7 +104,7 @@ def forward(self, q, k, v):
         if self.flash:
             return self.flash_attn(q, k, v)
 
-        scale = q.shape[-1] ** -0.5
+        scale = default(self.scale, q.shape[-1] ** -0.5)
 
         # similarity
 
```