
Commit 5b42eee (parent 60d544b)

add flash attention; allow fine-grained control over which layers of the unet use full attention
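For context, a minimal usage sketch of the two new `Unet` keyword arguments introduced by this commit (the import follows the package usage shown in the repository README; the values are illustrative, not prescriptive):

```python
from denoising_diffusion_pytorch import Unet

# full_attn chooses, per resolution stage (one entry per dim_mults entry),
# whether that stage uses full attention (True) or linear attention (False).
# flash_attn = True routes full attention through F.scaled_dot_product_attention
# and requires pytorch >= 2.0.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    full_attn = (False, False, False, True),  # default: full attention only at the coarsest stage
    flash_attn = True
)
```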

File tree

4 files changed (+178 / -42 lines)

README.md

Lines changed: 9 additions & 4 deletions
@@ -153,10 +153,6 @@ sampled_seq.shape # (4, 32, 128)
 You could consider adding a suitable metric to the training loop yourself after doing an editable install of this package
 `pip install -e .`.
 
-## Todo
-
-- [ ] add flash attention, do full attention at 64x64, linear attention at anything above
-
 ## Citations
 
 ```bibtex
@@ -322,3 +318,12 @@ You could consider adding a suitable metric to the training loop yourself after
     year = {2023}
 }
 ```
+
+```bibtex
+@inproceedings{dao2022flashattention,
+    title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
+    author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
+    booktitle = {Advances in Neural Information Processing Systems},
+    year = {2022}
+}
+```
denoising_diffusion_pytorch/attend.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+from functools import wraps
+from packaging import version
+from collections import namedtuple
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange
+
+# constants
+
+AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def once(fn):
+    called = False
+    @wraps(fn)
+    def inner(x):
+        nonlocal called
+        if called:
+            return
+        called = True
+        return fn(x)
+    return inner
+
+print_once = once(print)
+
+# main class
+
+class Attend(nn.Module):
+    def __init__(
+        self,
+        dropout = 0.,
+        flash = False
+    ):
+        super().__init__()
+        self.dropout = dropout
+        self.attn_dropout = nn.Dropout(dropout)
+
+        self.flash = flash
+        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+        # determine efficient attention configs for cuda and cpu
+
+        self.cpu_config = AttentionConfig(True, True, True)
+        self.cuda_config = None
+
+        if not torch.cuda.is_available() or not flash:
+            return
+
+        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+        if device_properties.major == 8 and device_properties.minor == 0:
+            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
+            self.cuda_config = AttentionConfig(True, False, False)
+        else:
+            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
+            self.cuda_config = AttentionConfig(False, True, True)
+
+    def flash_attn(self, q, k, v):
+        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+
+        q, k, v = map(lambda t: t.contiguous(), (q, k, v))
+
+        # Check if there is a compatible device for flash attention
+
+        config = self.cuda_config if is_cuda else self.cpu_config
+
+        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
+
+        with torch.backends.cuda.sdp_kernel(**config._asdict()):
+            out = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p = self.dropout if self.training else 0.
+            )
+
+        return out
+
+    def forward(self, q, k, v):
+        """
+        einstein notation
+        b - batch
+        h - heads
+        n, i, j - sequence length (base sequence length, source, target)
+        d - feature dimension
+        """
+
+        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+        if self.flash:
+            return self.flash_attn(q, k, v)
+
+        scale = q.shape[-1] ** -0.5
+
+        # similarity
+
+        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
+
+        # attention
+
+        attn = sim.softmax(dim = -1)
+        attn = self.attn_dropout(attn)
+
+        # aggregate values
+
+        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
+
+        return out
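As a quick sanity check, the new `Attend` module can also be exercised on its own; a minimal sketch, assuming the `(batch, heads, sequence, dim_head)` tensor layout described in its `forward` docstring:

```python
import torch
from denoising_diffusion_pytorch.attend import Attend

# random tensors standing in for the unet's projected q, k, v feature maps
q = torch.randn(1, 4, 1024, 32)   # (batch, heads, sequence, dim_head)
k = torch.randn(1, 4, 1024, 32)
v = torch.randn(1, 4, 1024, 32)

attend = Attend(flash = False)    # flash = True requires pytorch >= 2.0
out = attend(q, k, v)             # (1, 4, 1024, 32)
```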

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 55 additions & 37 deletions
@@ -25,6 +25,7 @@
 
 from accelerate import Accelerator
 
+from denoising_diffusion_pytorch.attend import Attend
 from denoising_diffusion_pytorch.fid_evaluation import FIDEvaluation
 
 from denoising_diffusion_pytorch.version import __version__
@@ -43,6 +44,11 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
+def cast_tuple(t, length = 1):
+    if isinstance(t, tuple):
+        return t
+    return ((t,) * length)
+
 def identity(t, *args, **kwargs):
     return t
 
@@ -77,14 +83,6 @@ def unnormalize_to_zero_to_one(t):
 
 # small helper modules
 
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x, *args, **kwargs):
-        return self.fn(x, *args, **kwargs) + x
-
 def Upsample(dim, dim_out = None):
     return nn.Sequential(
         nn.Upsample(scale_factor = 2, mode = 'nearest'),
@@ -105,16 +103,6 @@ def __init__(self, dim):
     def forward(self, x):
         return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)
 
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.fn = fn
-        self.norm = RMSNorm(dim)
-
-    def forward(self, x):
-        x = self.norm(x)
-        return self.fn(x)
-
 # sinusoidal positional embeds
 
 class SinusoidalPosEmb(nn.Module):
@@ -195,11 +183,18 @@ def forward(self, x, time_emb = None):
         return h + self.res_conv(x)
 
 class LinearAttention(nn.Module):
-    def __init__(self, dim, heads = 4, dim_head = 32):
+    def __init__(
+        self,
+        dim,
+        heads = 4,
+        dim_head = 32
+    ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
         hidden_dim = dim_head * heads
+
+        self.norm = RMSNorm(dim)
         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
 
         self.to_out = nn.Sequential(
@@ -209,6 +204,9 @@ def __init__(self, dim, heads = 4, dim_head = 32):
 
     def forward(self, x):
         b, c, h, w = x.shape
+
+        x = self.norm(x)
+
         qkv = self.to_qkv(x).chunk(3, dim = 1)
         q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
 
@@ -224,25 +222,32 @@ def forward(self, x):
         return self.to_out(out)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 4, dim_head = 32):
+    def __init__(
+        self,
+        dim,
+        heads = 4,
+        dim_head = 32,
+        flash = False
+    ):
         super().__init__()
-        self.scale = dim_head ** -0.5
         self.heads = heads
         hidden_dim = dim_head * heads
 
+        self.norm = RMSNorm(dim)
+        self.attend = Attend(flash = flash)
+
         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
 
     def forward(self, x):
         b, c, h, w = x.shape
-        qkv = self.to_qkv(x).chunk(3, dim = 1)
-        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
 
-        q = q * self.scale
+        x = self.norm(x)
 
-        sim = einsum('b h d i, b h d j -> b h i j', q, k)
-        attn = sim.softmax(dim = -1)
-        out = einsum('b h i j, b h d j -> b h i d', attn, v)
+        qkv = self.to_qkv(x).chunk(3, dim = 1)
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv)
+
+        out = self.attend(q, k, v)
 
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
         return self.to_out(out)
@@ -255,14 +260,16 @@ def __init__(
         dim,
         init_dim = None,
         out_dim = None,
-        dim_mults=(1, 2, 4, 8),
+        dim_mults = (1, 2, 4, 8),
         channels = 3,
         self_condition = False,
         resnet_block_groups = 8,
         learned_variance = False,
         learned_sinusoidal_cond = False,
         random_fourier_features = False,
-        learned_sinusoidal_dim = 16
+        learned_sinusoidal_dim = 16,
+        full_attn = (False, False, False, True),
+        flash_attn = False
     ):
         super().__init__()
 
@@ -300,34 +307,45 @@ def __init__(
             nn.Linear(time_dim, time_dim)
         )
 
+        # attention
+
+        full_attn = cast_tuple(full_attn, length = len(dim_mults))
+        assert len(full_attn) == len(dim_mults)
+
+        FullAttention = partial(Attention, flash = flash_attn)
+
         # layers
 
         self.downs = nn.ModuleList([])
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        for ind, (dim_in, dim_out) in enumerate(in_out):
+        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(in_out, full_attn)):
             is_last = ind >= (num_resolutions - 1)
 
+            attn_klass = FullAttention if layer_full_attn else LinearAttention
+
             self.downs.append(nn.ModuleList([
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
-                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
+                attn_klass(dim_in),
                 Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
             ]))
 
         mid_dim = dims[-1]
         self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
-        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
+        self.mid_attn = FullAttention(mid_dim)
         self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
 
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
+        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(reversed(in_out), reversed(full_attn))):
             is_last = ind == (len(in_out) - 1)
 
+            attn_klass = FullAttention if layer_full_attn else LinearAttention
+
             self.ups.append(nn.ModuleList([
                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                 block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
-                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
+                attn_klass(dim_out),
                 Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))
 
@@ -354,13 +372,13 @@ def forward(self, x, time, x_self_cond = None):
             h.append(x)
 
             x = block2(x, t)
-            x = attn(x)
+            x = attn(x) + x
             h.append(x)
 
             x = downsample(x)
 
         x = self.mid_block1(x, t)
-        x = self.mid_attn(x)
+        x = self.mid_attn(x) + x
         x = self.mid_block2(x, t)
 
         for block1, block2, attn, upsample in self.ups:
@@ -369,7 +387,7 @@ def forward(self, x, time, x_self_cond = None):
 
             x = torch.cat((x, h.pop()), dim = 1)
             x = block2(x, t)
-            x = attn(x)
+            x = attn(x) + x
 
             x = upsample(x)
 
denoising_diffusion_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.7.7'
+__version__ = '1.8.0'
