Commit 5f8aa9f

complete #295
1 parent 354a39b

File tree: 2 files changed, +77 -12 lines

denoising_diffusion_pytorch/karras_unet_3d.py (76 additions, 11 deletions)

@@ -2,6 +2,8 @@
 the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
 """
 
+from copy import deepcopy
+
 import math
 from math import sqrt, ceil
 from functools import partial

@@ -208,6 +210,7 @@ def __init__(
         attn_dim_head = 64,
         attn_res_mp_add_t = 0.3,
         attn_flash = False,
+        factorize_space_time_attn = False,
         downsample = False,
         downsample_config: Tuple[bool, bool, bool] = (True, True, True)
     ):

@@ -247,15 +250,25 @@ def __init__(
         self.res_mp_add = MPAdd(t = mp_add_t)
 
         self.attn = None
+        self.factorized_attn = factorize_space_time_attn
+
         if has_attn:
-            self.attn = Attention(
+            attn_kwargs = dict(
                 dim = dim_out,
                 heads = max(ceil(dim_out / attn_dim_head), 2),
                 dim_head = attn_dim_head,
                 mp_add_t = attn_res_mp_add_t,
                 flash = attn_flash
             )
 
+            if factorize_space_time_attn:
+                self.attn = nn.ModuleList([
+                    Attention(**attn_kwargs, only_space = True),
+                    Attention(**attn_kwargs, only_time = True),
+                ])
+            else:
+                self.attn = Attention(**attn_kwargs)
+
     def forward(
         self,
         x,

@@ -284,7 +297,13 @@ def forward(
         x = self.res_mp_add(x, res)
 
         if exists(self.attn):
-            x = self.attn(x)
+            if self.factorized_attn:
+                attn_space, attn_time = self.attn
+                x = attn_space(x)
+                x = attn_time(x)
+
+            else:
+                x = self.attn(x)
 
         return x
 

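When factorized_attn is set, the block above runs spatial attention and temporal attention back to back instead of one joint attention over every space-time token. A rough cost sketch of why that helps, assuming attention cost grows with the square of the sequence length (illustrative numbers only, not part of the commit):

    # joint space-time attention: one sequence of t * h * w tokens
    t, h, w = 32, 64, 64
    joint_cost = (t * h * w) ** 2                            # ~1.7e10 pairwise interactions

    # factorized: t spatial sequences of h * w tokens, then h * w temporal sequences of t tokens
    factorized_cost = t * (h * w) ** 2 + (h * w) * t ** 2    # ~5.4e8

    print(joint_cost / factorized_cost)                      # roughly a 30x reduction at this shape
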
@@ -301,6 +320,7 @@ def __init__(
         attn_dim_head = 64,
         attn_res_mp_add_t = 0.3,
         attn_flash = False,
+        factorize_space_time_attn = False,
         upsample = False,
         upsample_config: Tuple[bool, bool, bool] = (True, True, True)
     ):

@@ -335,15 +355,25 @@ def __init__(
         self.res_mp_add = MPAdd(t = mp_add_t)
 
         self.attn = None
+        self.factorized_attn = factorize_space_time_attn
+
         if has_attn:
-            self.attn = Attention(
+            attn_kwargs = dict(
                 dim = dim_out,
                 heads = max(ceil(dim_out / attn_dim_head), 2),
                 dim_head = attn_dim_head,
                 mp_add_t = attn_res_mp_add_t,
                 flash = attn_flash
             )
 
+            if factorize_space_time_attn:
+                self.attn = nn.ModuleList([
+                    Attention(**attn_kwargs, only_space = True),
+                    Attention(**attn_kwargs, only_time = True),
+                ])
+            else:
+                self.attn = Attention(**attn_kwargs)
+
     def forward(
         self,
         x,

@@ -369,7 +399,13 @@ def forward(
         x = self.res_mp_add(x, res)
 
         if exists(self.attn):
-            x = self.attn(x)
+            if self.factorized_attn:
+                attn_space, attn_time = self.attn
+                x = attn_space(x)
+                x = attn_time(x)
+
+            else:
+                x = self.attn(x)
 
         return x
 

@@ -383,9 +419,13 @@ def __init__(
         dim_head = 64,
         num_mem_kv = 4,
         flash = False,
-        mp_add_t = 0.3
+        mp_add_t = 0.3,
+        only_space = False,
+        only_time = False
     ):
         super().__init__()
+        assert (int(only_space) + int(only_time)) <= 1
+
         self.heads = heads
         hidden_dim = dim_head * heads
 

@@ -399,20 +439,41 @@ def __init__(
 
         self.mp_add = MPAdd(t = mp_add_t)
 
+        self.only_space = only_space
+        self.only_time = only_time
+
     def forward(self, x):
-        res, b, c, t, h, w = x, *x.shape
+        res, orig_shape = x, x.shape
+        b, c, t, h, w = orig_shape
+
+        qkv = self.to_qkv(x)
+
+        if self.only_space:
+            qkv = rearrange(qkv, 'b c t x y -> (b t) c x y')
+        elif self.only_time:
+            qkv = rearrange(qkv, 'b c t x y -> (b x y) c t')
+
+        qkv = qkv.chunk(3, dim = 1)
 
-        qkv = self.to_qkv(x).chunk(3, dim = 1)
-        q, k, v = map(lambda t: rearrange(t, 'b (h c) t x y -> b h (t x y) c', h = self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), qkv)
+
+        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = k.shape[0]), self.mem_kv)
 
-        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
         k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))
 
         q, k, v = map(self.pixel_norm, (q, k, v))
 
         out = self.attend(q, k, v)
 
-        out = rearrange(out, 'b h (t x y) d -> b (h d) t x y', t = t, x = h, y = w)
+        out = rearrange(out, 'b h n d -> b (h d) n')
+
+        if self.only_space:
+            out = rearrange(out, '(b t) c n -> b c (t n)', t = t)
+        elif self.only_time:
+            out = rearrange(out, '(b x y) c n -> b c (n x y)', x = h, y = w)
+
+        out = out.reshape(orig_shape)
+
         out = self.to_out(out)
 
         return self.mp_add(out, res)

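For orientation, a minimal standalone sketch of the folding pattern used above: time is folded into the batch for space-only attention, and spatial positions are folded into the batch for time-only attention. Shapes and names below are illustrative assumptions, not part of the commit.

    import torch
    from einops import rearrange

    b, c, t, h, w = 2, 16, 4, 8, 8
    x = torch.randn(b, c, t, h, w)

    # space-only: each frame attends over its h * w spatial positions
    space_tokens = rearrange(x, 'b c t x y -> (b t) c x y')
    assert space_tokens.shape == (b * t, c, h, w)

    # time-only: each spatial position attends over its t frames
    time_tokens = rearrange(x, 'b c t x y -> (b x y) c t')
    assert time_tokens.shape == (b * h * w, c, t)

    # after attention the result is flattened and restored to the original
    # video shape, mirroring the out.reshape(orig_shape) above
    restored = rearrange(space_tokens, '(b t) c x y -> b c t x y', t = t)
    assert restored.shape == (b, c, t, h, w)
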
@@ -446,7 +507,8 @@ def __init__(
         attn_res_mp_add_t = 0.3,
         resnet_mp_add_t = 0.3,
         dropout = 0.1,
-        self_condition = False
+        self_condition = False,
+        factorize_space_time_attn = False
     ):
         super().__init__()
 

@@ -576,6 +638,7 @@ def __init__(
                 has_attn = curr_image_res in attn_res,
                 upsample = True,
                 upsample_config = down_and_upsample_config,
+                factorize_space_time_attn = factorize_space_time_attn,
                 **block_kwargs
             )
 

@@ -593,6 +656,7 @@ def __init__(
                 downsample = True,
                 downsample_config = down_and_upsample_config,
                 has_attn = has_attn,
+                factorize_space_time_attn = factorize_space_time_attn,
                 **block_kwargs
             )
 

@@ -777,6 +841,7 @@ def forward(self, x):
         ),
         attn_dim_head = 8,
         num_classes = 1000,
+        factorize_space_time_attn = True # whether to do attention across space and time separately
     )
 
     video = torch.randn(2, 4, 32, 64, 64)

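The example above is shown only partially in the diff. A hedged sketch of how the new flag could be exercised end to end follows; constructor and forward arguments other than attn_dim_head, num_classes and factorize_space_time_attn are assumptions about this file's example, not something the commit shows.

    import torch
    from denoising_diffusion_pytorch.karras_unet_3d import KarrasUnet3D

    unet = KarrasUnet3D(
        frames = 32,                       # assumed to match the 32-frame video below
        image_size = 64,                   # assumed to match the 64 x 64 video below
        dim = 8,                           # assumed small width for a quick smoke test
        attn_dim_head = 8,                 # shown in the diff
        num_classes = 1000,                # shown in the diff
        factorize_space_time_attn = True   # the flag added in this commit
    )

    video = torch.randn(2, 4, 32, 64, 64)  # (batch, channels, frames, height, width)

    # assumed forward call: noised video, diffusion time, optional class labels
    denoised = unet(video, time = torch.ones(2,), class_labels = torch.randint(0, 1000, (2,)))
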
(second file) 1 addition, 1 deletion

@@ -1 +1 @@
-__version__ = '1.10.17'
+__version__ = '1.11.0'
