 #
 import abc
 import math
+
 import torch
 from einops import rearrange
 from torch import nn
@@ -41,22 +42,23 @@ def __init__(self, dim: int, num_heads: int):
 
     def forward(self, query_id, kv_id):
         ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2))
-        weight = rearrange(
-            self.emb.weight, "two num_heads -> two num_heads 1 1")
+        weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1")
         bias = ~ind * weight[:1] + ind * weight[1:]
         return bias
 
 
-def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+def _relative_position_bucket(
+    relative_position, bidirectional=True, num_buckets=32, max_distance=128
+):
     relative_buckets = 0
     if bidirectional:
         num_buckets //= 2
-        relative_buckets += (relative_position >
-                             0).to(torch.long) * num_buckets
+        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
         relative_position = torch.abs(relative_position)
     else:
-        relative_position = - \
-            torch.min(relative_position, torch.zeros_like(relative_position))
+        relative_position = -torch.min(
+            relative_position, torch.zeros_like(relative_position)
+        )
 
     max_exact = num_buckets // 2
     is_small = relative_position < max_exact
@@ -66,12 +68,13 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
         * (num_buckets - max_exact)
     ).to(torch.long)
     relative_position_if_large = torch.min(
-        relative_position_if_large, torch.full_like(
-            relative_position_if_large, num_buckets - 1)
+        relative_position_if_large,
+        torch.full_like(relative_position_if_large, num_buckets - 1),
     )
 
-    relative_buckets += torch.where(is_small,
-                                    relative_position, relative_position_if_large)
+    relative_buckets += torch.where(
+        is_small, relative_position, relative_position_if_large
+    )
     return relative_buckets
 
 
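For context (not part of the diff): a minimal sketch of how the reformatted _relative_position_bucket is called, assuming the standard T5-style bucketing whose tail is visible in the hunk above. The sample tensor below is illustrative only.

import torch

# Offsets rel[i, j] = j - i, mirroring memory_position - context_position
# in the forward pass further below.
rel = torch.arange(6)[None, :] - torch.arange(6)[:, None]
buckets = _relative_position_bucket(
    rel, bidirectional=False, num_buckets=32, max_distance=128
)
# With bidirectional=False, positive (future) offsets collapse to bucket 0,
# distances below num_buckets // 2 keep their own bucket, and larger
# distances share log-spaced buckets capped at num_buckets - 1.
print(buckets)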
@@ -83,11 +86,21 @@ def __init__(self, dim: int, num_heads: int):
         self.relative_attention_bias = nn.Embedding(self.num_buckets, 1)
 
     def forward(self, n_vars, n_tokens):
-        context_position = torch.arange(n_tokens, dtype=torch.long,)[:, None]
-        memory_position = torch.arange(n_tokens, dtype=torch.long, )[None, :]
+        context_position = torch.arange(
+            n_tokens,
+            dtype=torch.long,
+        )[:, None]
+        memory_position = torch.arange(
+            n_tokens,
+            dtype=torch.long,
+        )[None, :]
         relative_position = memory_position - context_position
-        bucket = _relative_position_bucket(relative_position=relative_position, bidirectional=False,
-                                           num_buckets=self.num_buckets, max_distance=self.max_distance).to(self.relative_attention_bias.weight.device)
+        bucket = _relative_position_bucket(
+            relative_position=relative_position,
+            bidirectional=False,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        ).to(self.relative_attention_bias.weight.device)
         bias = self.relative_attention_bias(bucket).squeeze(-1)
         bias = bias.reshape(1, 1, bias.shape[0], bias.shape[1])
         mask1 = torch.ones((n_vars, n_vars), dtype=torch.bool).to(bias.device)
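As a usage note (again outside the diff): the bias built above has shape (1, 1, n_tokens, n_tokens), so it broadcasts over batch and heads when added to attention logits. The helper below is a hypothetical sketch of that consumption, not code from this commit.

import torch
import torch.nn.functional as F

def attend_with_bias(q, k, v, bias):
    # q, k, v: (batch, num_heads, n_tokens, head_dim); bias broadcasts over
    # batch and heads and shifts the logits before the softmax.
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / q.shape[-1] ** 0.5
    return F.softmax(scores + bias, dim=-1) @ v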