@@ -315,6 +315,55 @@ def __init__(
     """


+class RelativePositionBias(nn.Module):
+    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
+        super().__init__()
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        self.num_heads = num_heads
+        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position: Tensor, num_buckets: int, max_distance: int
+    ):
+        num_buckets //= 2
+        ret = (relative_position >= 0).to(torch.long) * num_buckets
+        n = torch.abs(relative_position)
+
+        max_exact = num_buckets // 2
+        is_small = n < max_exact
+
+        val_if_large = (
+            max_exact
+            + (
+                torch.log(n.float() / max_exact)
+                / math.log(max_distance / max_exact)
+                * (num_buckets - max_exact)
+            ).long()
+        )
+        val_if_large = torch.min(
+            val_if_large, torch.full_like(val_if_large, num_buckets - 1)
+        )
+
+        ret += torch.where(is_small, n, val_if_large)
+        return ret
+
+    def forward(self, num_queries: int, num_keys: int) -> Tensor:
+        i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
+        q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
+        k_pos = torch.arange(j, dtype=torch.long, device=device)
+        rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
+
+        relative_position_bucket = self._relative_position_bucket(
+            rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
+        )
+
+        bias = self.relative_attention_bias(relative_position_bucket)
+        bias = rearrange(bias, "m n h -> 1 h m n")
+        return bias
+
+
 def FeedForward(features: int, multiplier: int) -> nn.Module:
     mid_features = features * multiplier
     return nn.Sequential(
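Note: the class above uses the bidirectional bucketing scheme popularized by T5, mapping signed key/query offsets to a fixed number of buckets (exact for small offsets, logarithmic up to max_distance) and learning one embedding per bucket and head. A minimal usage sketch, assuming the class above is importable from this module; the sizes are illustrative only:

rel_pos = RelativePositionBias(num_buckets=32, max_distance=128, num_heads=8)
bias = rel_pos(num_queries=64, num_keys=64)
print(bias.shape)  # torch.Size([1, 8, 64, 64]); the leading 1 broadcasts over the batch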
@@ -331,19 +380,33 @@ def __init__(
         *,
         head_features: int,
         num_heads: int,
+        use_rel_pos: bool,
+        rel_pos_num_buckets: Optional[int] = None,
+        rel_pos_max_distance: Optional[int] = None,
     ):
         super().__init__()
         self.scale = head_features ** -0.5
         self.num_heads = num_heads
+        self.use_rel_pos = use_rel_pos
         mid_features = head_features * num_heads

+        if use_rel_pos:
+            assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
+            self.rel_pos = RelativePositionBias(
+                num_buckets=rel_pos_num_buckets,
+                max_distance=rel_pos_max_distance,
+                num_heads=num_heads,
+            )
+
         self.to_out = nn.Linear(in_features=mid_features, out_features=features)

     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
         # Split heads
         q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
         # Compute similarity matrix
-        sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
+        sim = einsum("... n d, ... m d -> ... n m", q, k)
+        sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
+        sim = sim * self.scale
         # Get attention matrix with softmax
         attn = sim.softmax(dim=-1)
         # Compute values
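As written, the bias is added to the raw q·k logits and is therefore multiplied by self.scale together with them. A shape sketch of that addition, with illustrative sizes not taken from the diff:

import torch

b, h, n, d = 2, 8, 64, 64       # batch, heads, sequence length, head_features
sim = torch.randn(b, h, n, n)   # stand-in for einsum("... n d, ... m d -> ... n m", q, k)
bias = torch.randn(1, h, n, n)  # stand-in for the RelativePositionBias output for (n, n)
sim = (sim + bias) * d ** -0.5  # bias broadcasts over batch; scaling is applied after the add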
@@ -360,6 +423,9 @@ def __init__(
         head_features: int,
         num_heads: int,
         context_features: Optional[int] = None,
+        use_rel_pos: bool,
+        rel_pos_num_buckets: Optional[int] = None,
+        rel_pos_max_distance: Optional[int] = None,
     ):
         super().__init__()
         self.context_features = context_features
@@ -375,7 +441,12 @@ def __init__(
             in_features=context_features, out_features=mid_features * 2, bias=False
         )
         self.attention = AttentionBase(
-            features, num_heads=num_heads, head_features=head_features
+            features,
+            num_heads=num_heads,
+            head_features=head_features,
+            use_rel_pos=use_rel_pos,
+            rel_pos_num_buckets=rel_pos_num_buckets,
+            rel_pos_max_distance=rel_pos_max_distance,
         )

     def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
@@ -402,14 +473,22 @@ def __init__(
         num_heads: int,
         head_features: int,
         multiplier: int,
+        use_rel_pos: bool,
+        rel_pos_num_buckets: Optional[int] = None,
+        rel_pos_max_distance: Optional[int] = None,
         context_features: Optional[int] = None,
     ):
         super().__init__()

         self.use_cross_attention = exists(context_features) and context_features > 0

         self.attention = Attention(
-            features=features, num_heads=num_heads, head_features=head_features
+            features=features,
+            num_heads=num_heads,
+            head_features=head_features,
+            use_rel_pos=use_rel_pos,
+            rel_pos_num_buckets=rel_pos_num_buckets,
+            rel_pos_max_distance=rel_pos_max_distance,
         )

         if self.use_cross_attention:
@@ -418,6 +497,9 @@ def __init__(
                 num_heads=num_heads,
                 head_features=head_features,
                 context_features=context_features,
+                use_rel_pos=use_rel_pos,
+                rel_pos_num_buckets=rel_pos_num_buckets,
+                rel_pos_max_distance=rel_pos_max_distance,
             )

         self.feed_forward = FeedForward(features=features, multiplier=multiplier)
@@ -443,6 +525,9 @@ def __init__(
         num_heads: int,
         head_features: int,
         multiplier: int,
+        use_rel_pos: bool,
+        rel_pos_num_buckets: Optional[int] = None,
+        rel_pos_max_distance: Optional[int] = None,
         context_features: Optional[int] = None,
     ):
         super().__init__()
@@ -465,6 +550,9 @@ def __init__(
                     num_heads=num_heads,
                     multiplier=multiplier,
                     context_features=context_features,
+                    use_rel_pos=use_rel_pos,
+                    rel_pos_num_buckets=rel_pos_num_buckets,
+                    rel_pos_max_distance=rel_pos_max_distance,
                 )
                 for i in range(num_layers)
             ]
@@ -552,6 +640,9 @@ def __init__(
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
+        attention_use_rel_pos: Optional[bool] = None,
+        attention_rel_pos_max_distance: Optional[int] = None,
+        attention_rel_pos_num_buckets: Optional[int] = None,
         context_mapping_features: Optional[int] = None,
         context_embedding_features: Optional[int] = None,
     ):
@@ -588,6 +679,7 @@ def __init__(
                 exists(attention_heads)
                 and exists(attention_features)
                 and exists(attention_multiplier)
+                and exists(attention_use_rel_pos)
             )
             self.transformer = Transformer1d(
                 num_layers=num_transformer_blocks,
@@ -596,6 +688,9 @@ def __init__(
                 head_features=attention_features,
                 multiplier=attention_multiplier,
                 context_features=context_embedding_features,
+                use_rel_pos=attention_use_rel_pos,
+                rel_pos_num_buckets=attention_rel_pos_num_buckets,
+                rel_pos_max_distance=attention_rel_pos_max_distance,
             )

         if self.use_extract:
@@ -659,6 +754,9 @@ def __init__(
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
+        attention_use_rel_pos: Optional[bool] = None,
+        attention_rel_pos_max_distance: Optional[int] = None,
+        attention_rel_pos_num_buckets: Optional[int] = None,
         context_mapping_features: Optional[int] = None,
         context_embedding_features: Optional[int] = None,
     ):
@@ -689,6 +787,7 @@ def __init__(
                 exists(attention_heads)
                 and exists(attention_features)
                 and exists(attention_multiplier)
+                and exists(attention_use_rel_pos)
             )
             self.transformer = Transformer1d(
                 num_layers=num_transformer_blocks,
@@ -697,6 +796,9 @@ def __init__(
                 head_features=attention_features,
                 multiplier=attention_multiplier,
                 context_features=context_embedding_features,
+                use_rel_pos=attention_use_rel_pos,
+                rel_pos_num_buckets=attention_rel_pos_num_buckets,
+                rel_pos_max_distance=attention_rel_pos_max_distance,
             )

         self.upsample = Upsample1d(
@@ -756,6 +858,9 @@ def __init__(
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
+        attention_use_rel_pos: Optional[bool] = None,
+        attention_rel_pos_max_distance: Optional[int] = None,
+        attention_rel_pos_num_buckets: Optional[int] = None,
         context_mapping_features: Optional[int] = None,
         context_embedding_features: Optional[int] = None,
     ):
@@ -774,6 +879,7 @@ def __init__(
                 exists(attention_heads)
                 and exists(attention_features)
                 and exists(attention_multiplier)
+                and exists(attention_use_rel_pos)
             )
             self.transformer = Transformer1d(
                 num_layers=num_transformer_blocks,
@@ -782,6 +888,9 @@ def __init__(
                 head_features=attention_features,
                 multiplier=attention_multiplier,
                 context_features=context_embedding_features,
+                use_rel_pos=attention_use_rel_pos,
+                rel_pos_num_buckets=attention_rel_pos_num_buckets,
+                rel_pos_max_distance=attention_rel_pos_max_distance,
             )

         self.post_block = ResnetBlock1d(
@@ -844,6 +953,9 @@ def __init__(
         context_features: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
         context_embedding_features: Optional[int] = None,
+        attention_use_rel_pos: bool = False,
+        attention_rel_pos_max_distance: Optional[int] = None,
+        attention_rel_pos_num_buckets: Optional[int] = None,
     ):
         super().__init__()
         out_channels = default(out_channels, in_channels)
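At the top level the feature is off by default (attention_use_rel_pos=False) and the bucket/distance arguments default to None, so enabling it means supplying all three; the assert in AttentionBase then checks that both the bucket count and the max distance exist. A hypothetical call fragment — argument names are taken from the diff, the values and the remaining UNet1d arguments are illustrative:

unet_rel_pos_kwargs = dict(
    attention_use_rel_pos=True,
    attention_rel_pos_num_buckets=32,    # illustrative value
    attention_rel_pos_max_distance=128,  # illustrative value
)
# model = UNet1d(..., **unet_rel_pos_kwargs)  # other constructor arguments as before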
@@ -931,6 +1043,9 @@ def __init__(
                     attention_heads=attention_heads,
                     attention_features=attention_features,
                     attention_multiplier=attention_multiplier,
+                    attention_use_rel_pos=attention_use_rel_pos,
+                    attention_rel_pos_max_distance=attention_rel_pos_max_distance,
+                    attention_rel_pos_num_buckets=attention_rel_pos_num_buckets,
                 )
                 for i in range(num_layers)
             ]
@@ -945,6 +1060,9 @@ def __init__(
             attention_heads=attention_heads,
             attention_features=attention_features,
             attention_multiplier=attention_multiplier,
+            attention_use_rel_pos=attention_use_rel_pos,
+            attention_rel_pos_max_distance=attention_rel_pos_max_distance,
+            attention_rel_pos_num_buckets=attention_rel_pos_num_buckets,
         )

         self.upsamples = nn.ModuleList(
@@ -966,6 +1084,9 @@ def __init__(
                     attention_heads=attention_heads,
                     attention_features=attention_features,
                     attention_multiplier=attention_multiplier,
+                    attention_use_rel_pos=attention_use_rel_pos,
+                    attention_rel_pos_max_distance=attention_rel_pos_max_distance,
+                    attention_rel_pos_num_buckets=attention_rel_pos_num_buckets,
                 )
                 for i in reversed(range(num_layers))
             ]