 the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
 """
 
+from copy import deepcopy
+
 import math
 from math import sqrt, ceil
 from functools import partial
@@ -208,6 +210,7 @@ def __init__(
         attn_dim_head = 64,
         attn_res_mp_add_t = 0.3,
         attn_flash = False,
+        factorize_space_time_attn = False,
         downsample = False,
         downsample_config: Tuple[bool, bool, bool] = (True, True, True)
     ):
@@ -247,15 +250,25 @@ def __init__(
         self.res_mp_add = MPAdd(t = mp_add_t)
 
         self.attn = None
+        self.factorized_attn = factorize_space_time_attn
+
         if has_attn:
-            self.attn = Attention(
+            attn_kwargs = dict(
                 dim = dim_out,
                 heads = max(ceil(dim_out / attn_dim_head), 2),
                 dim_head = attn_dim_head,
                 mp_add_t = attn_res_mp_add_t,
                 flash = attn_flash
             )
 
+            if factorize_space_time_attn:
+                self.attn = nn.ModuleList([
+                    Attention(**attn_kwargs, only_space = True),
+                    Attention(**attn_kwargs, only_time = True),
+                ])
+            else:
+                self.attn = Attention(**attn_kwargs)
+
     def forward(
         self,
         x,
@@ -284,7 +297,13 @@ def forward(
         x = self.res_mp_add(x, res)
 
         if exists(self.attn):
-            x = self.attn(x)
+            if self.factorized_attn:
+                attn_space, attn_time = self.attn
+                x = attn_space(x)
+                x = attn_time(x)
+
+            else:
+                x = self.attn(x)
 
         return x
 
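When `factorized_attn` is set, the block above applies full attention within each frame and then attention along the time axis at each spatial position, instead of a single joint attention over all space-time tokens. A rough back-of-the-envelope sketch of what this does to the sequence length each softmax sees (illustration only, not code from the repo; the feature-map size is hypothetical):

```python
# Token counts per attention call at a hypothetical attention-resolution feature map.
t, h, w = 32, 16, 16

joint_tokens = t * h * w   # 8192 tokens when attending over space and time jointly
space_tokens = h * w       # 256 tokens per frame (time folded into the batch)
time_tokens  = t           # 32 tokens per spatial position (space folded into the batch)

print(joint_tokens, space_tokens, time_tokens)   # 8192 256 32
```

Since attention cost grows quadratically with sequence length, the factorized pair is far cheaper per block at video-scale resolutions, at the cost of never mixing space and time within a single attention step.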
@@ -301,6 +320,7 @@ def __init__(
         attn_dim_head = 64,
         attn_res_mp_add_t = 0.3,
         attn_flash = False,
+        factorize_space_time_attn = False,
         upsample = False,
         upsample_config: Tuple[bool, bool, bool] = (True, True, True)
     ):
@@ -335,15 +355,25 @@ def __init__(
         self.res_mp_add = MPAdd(t = mp_add_t)
 
         self.attn = None
+        self.factorized_attn = factorize_space_time_attn
+
         if has_attn:
-            self.attn = Attention(
+            attn_kwargs = dict(
                 dim = dim_out,
                 heads = max(ceil(dim_out / attn_dim_head), 2),
                 dim_head = attn_dim_head,
                 mp_add_t = attn_res_mp_add_t,
                 flash = attn_flash
             )
 
+            if factorize_space_time_attn:
+                self.attn = nn.ModuleList([
+                    Attention(**attn_kwargs, only_space = True),
+                    Attention(**attn_kwargs, only_time = True),
+                ])
+            else:
+                self.attn = Attention(**attn_kwargs)
+
     def forward(
         self,
         x,
@@ -369,7 +399,13 @@ def forward(
         x = self.res_mp_add(x, res)
 
         if exists(self.attn):
-            x = self.attn(x)
+            if self.factorized_attn:
+                attn_space, attn_time = self.attn
+                x = attn_space(x)
+                x = attn_time(x)
+
+            else:
+                x = self.attn(x)
 
         return x
 
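Both the encoder and decoder merge the residual path and the attention output with `MPAdd` rather than a plain sum; `mp_add_t` and `attn_res_mp_add_t` above are its blend factor. A minimal sketch, assuming `MPAdd` follows the magnitude-preserving sum from Karras et al. (the function name `mp_sum` here is illustrative, not from the file):

```python
import torch
from math import sqrt

# Hedged sketch of a magnitude-preserving sum: lerp two branches by t, then
# rescale so two unit-variance inputs give an (approximately) unit-variance output.
def mp_sum(a, b, t = 0.3):
    return (a * (1. - t) + b * t) / sqrt((1. - t) ** 2 + t ** 2)

x   = torch.randn(2, 64, 8, 16, 16)
res = torch.randn(2, 64, 8, 16, 16)
print(mp_sum(x, res).std())   # stays close to 1.0
```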
@@ -383,9 +419,13 @@ def __init__(
         dim_head = 64,
         num_mem_kv = 4,
         flash = False,
-        mp_add_t = 0.3
+        mp_add_t = 0.3,
+        only_space = False,
+        only_time = False
     ):
         super().__init__()
+        assert (int(only_space) + int(only_time)) <= 1
+
         self.heads = heads
         hidden_dim = dim_head * heads
 
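`only_space` and `only_time` are mutually exclusive, as the new assert enforces; when factorizing, the encoder/decoder blocks above construct one attention of each kind. For reference, those blocks size the heads as `max(ceil(dim_out / attn_dim_head), 2)`, so the inner width `heads * dim_head` equals `dim_out` whenever `dim_out` is a multiple of `attn_dim_head`. A small arithmetic check (illustration only), using the `attn_dim_head = 8` value from the example at the bottom of the file:

```python
from math import ceil

dim_out, attn_dim_head = 64, 8
heads = max(ceil(dim_out / attn_dim_head), 2)   # -> 8
hidden_dim = heads * attn_dim_head              # -> 64, equal to dim_out here
assert (heads, hidden_dim) == (8, 64)
```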
@@ -399,20 +439,41 @@ def __init__(
 
         self.mp_add = MPAdd(t = mp_add_t)
 
+        self.only_space = only_space
+        self.only_time = only_time
+
     def forward(self, x):
-        res, b, c, t, h, w = x, *x.shape
+        res, orig_shape = x, x.shape
+        b, c, t, h, w = orig_shape
+
+        qkv = self.to_qkv(x)
+
+        if self.only_space:
+            qkv = rearrange(qkv, 'b c t x y -> (b t) c x y')
+        elif self.only_time:
+            qkv = rearrange(qkv, 'b c t x y -> (b x y) c t')
+
+        qkv = qkv.chunk(3, dim = 1)
 
-        qkv = self.to_qkv(x).chunk(3, dim = 1)
-        q, k, v = map(lambda t: rearrange(t, 'b (h c) t x y -> b h (t x y) c', h = self.heads), qkv)
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), qkv)
+
+        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = k.shape[0]), self.mem_kv)
 
-        mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
         k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))
 
         q, k, v = map(self.pixel_norm, (q, k, v))
 
         out = self.attend(q, k, v)
 
-        out = rearrange(out, 'b h (t x y) d -> b (h d) t x y', t = t, x = h, y = w)
+        out = rearrange(out, 'b h n d -> b (h d) n')
+
+        if self.only_space:
+            out = rearrange(out, '(b t) c n -> b c (t n)', t = t)
+        elif self.only_time:
+            out = rearrange(out, '(b x y) c n -> b c (n x y)', x = h, y = w)
+
+        out = out.reshape(orig_shape)
+
         out = self.to_out(out)
 
         return self.mp_add(out, res)
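The reworked forward folds the non-attended axes into the batch dimension right after `to_qkv` and unfolds them again before `to_out`: for space-only attention the time axis joins the batch so each frame attends independently, and for time-only attention the two spatial axes join the batch so each position attends along time. A standalone sketch of that folding pattern (illustration only, not the repo's `Attention` class: `F.scaled_dot_product_attention` stands in for `Attend`, and the memory key/values, pixel norm and projections are omitted):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def factorized_attention(x, heads = 4, only_space = False, only_time = False):
    assert int(only_space) + int(only_time) <= 1
    b, c, t, h, w = x.shape
    assert c % heads == 0

    if only_space:
        tokens = rearrange(x, 'b c t x y -> (b t) (x y) c')   # each frame attends over its pixels
    elif only_time:
        tokens = rearrange(x, 'b c t x y -> (b x y) t c')     # each pixel attends over time
    else:
        tokens = rearrange(x, 'b c t x y -> b (t x y) c')     # joint space-time attention

    q = k = v = rearrange(tokens, 'b n (h d) -> b h n d', h = heads)
    out = F.scaled_dot_product_attention(q, k, v)
    out = rearrange(out, 'b h n d -> b n (h d)')

    if only_space:
        out = rearrange(out, '(b t) (x y) c -> b c t x y', t = t, x = h, y = w)
    elif only_time:
        out = rearrange(out, '(b x y) t c -> b c t x y', x = h, y = w)
    else:
        out = rearrange(out, 'b (t x y) c -> b c t x y', t = t, x = h, y = w)

    return out

video_feats = torch.randn(2, 32, 8, 16, 16)
assert factorized_attention(video_feats, only_space = True).shape == video_feats.shape
assert factorized_attention(video_feats, only_time = True).shape == video_feats.shape
```

Note that the module's final `out.reshape(orig_shape)` assumes the attention's inner width `heads * dim_head` matches the block's channel count, which holds under the head sizing shown earlier (i.e. when the channel count is a multiple of `attn_dim_head`).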
@@ -446,7 +507,8 @@ def __init__(
         attn_res_mp_add_t = 0.3,
         resnet_mp_add_t = 0.3,
         dropout = 0.1,
-        self_condition = False
+        self_condition = False,
+        factorize_space_time_attn = False
     ):
         super().__init__()
 
@@ -576,6 +638,7 @@ def __init__(
                 has_attn = curr_image_res in attn_res,
                 upsample = True,
                 upsample_config = down_and_upsample_config,
+                factorize_space_time_attn = factorize_space_time_attn,
                 **block_kwargs
             )
 
@@ -593,6 +656,7 @@ def __init__(
                 downsample = True,
                 downsample_config = down_and_upsample_config,
                 has_attn = has_attn,
+                factorize_space_time_attn = factorize_space_time_attn,
                 **block_kwargs
             )
 
@@ -777,6 +841,7 @@ def forward(self, x):
     ),
     attn_dim_head = 8,
     num_classes = 1000,
+    factorize_space_time_attn = True  # whether to do attention across space and time separately
 )
 
 video = torch.randn(2, 4, 32, 64, 64)