@@ -53,6 +53,38 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)

+# for changing full attention bias matrix to a local windowed one for atom attention
+
+@typecheck
+def full_attn_bias_matrix_to_local(
+    attn_bias: Float['... m m'],
+    window_size: int
+) -> Float['... n w (w*3)']:
+
+    seq_len, device = attn_bias.shape[-1], attn_bias.device
+
+    padding_needed = (window_size - (seq_len % window_size)) % window_size
+    attn_bias = F.pad(attn_bias, (0, padding_needed, 0, padding_needed), value = 0.)
+    attn_bias = rearrange(attn_bias, '... (i w1) (j w2) -> ... i j w1 w2', w1 = window_size, w2 = window_size)
+    attn_bias = pad_at_dim(attn_bias, (1, 1), dim = -3, value = 0.)
+
+    attn_bias = torch.cat((
+        attn_bias[..., :-2, :, :],
+        attn_bias[..., 1:-1, :, :],
+        attn_bias[..., 2:, :, :]
+    ), dim = -1)
+
+    # get the diagonal
+
+    n = torch.arange(attn_bias.shape[-3], device = device)
+
+    attn_bias = einx.get_at(
+        '... [i j] w1 w2, n, n -> ... n w1 w2',
+        attn_bias, n, n
+    )
+
+    return attn_bias
+
 # multi-head attention

 class Attention(Module):
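
For reference, a minimal shape check of the new helper (a sketch only; it assumes this repo's typecheck / Float wrappers and einx are importable, and the tensor sizes are made up for illustration):

    import torch

    # hypothetical input: a full pairwise bias over 10 tokens, batch of 2, window size 4
    attn_bias = torch.randn(2, 10, 10)
    local_bias = full_attn_bias_matrix_to_local(attn_bias, window_size = 4)

    # 10 tokens pad up to 12 = 3 windows of 4; each query window sees its own
    # window plus one neighbor on each side, so the key dimension is 4 * 3
    assert local_bias.shape == (2, 3, 4, 12)
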
@@ -218,7 +250,7 @@ def local_attn(
     k: Float['b h n d'],
     v: Float['b h n d'],
     mask: Bool['b n'] | None = None,
-    attn_bias: Float['... n n'] | None = None
+    attn_bias: Float['... n n'] | Float['... n w (w*3)'] | None = None
 ) -> Float['b h n d']:
     """
     simple local attention with a radius of 1 window size
@@ -233,7 +265,7 @@ def local_attn(

     # pad to multiple of window size if needed

-    padding_needed = (window_size - (seq_len % window_size)) % window_size
+    padding_needed = (window_size - (seq_len % window_size)) % window_size

     if padding_needed > 0:
         q, k, v = tuple(pad_at_dim(t, (0, padding_needed), value = 0., dim = -2) for t in (q, k, v))
@@ -255,25 +287,10 @@ def local_attn(

     # handle attention bias (inefficiently)

-    if exists(attn_bias):
-        attn_bias = F.pad(attn_bias, (0, padding_needed, 0, padding_needed), value = 0.)
-        attn_bias = rearrange(attn_bias, '... (i w1) (j w2) -> ... i j w1 w2', w1 = window_size, w2 = window_size)
-        attn_bias = pad_at_dim(attn_bias, (1, 1), dim = -3, value = 0.)
-
-        attn_bias = torch.cat((
-            attn_bias[..., :-2, :, :],
-            attn_bias[..., 1:-1, :, :],
-            attn_bias[..., 2:, :, :]
-        ), dim = -1)
+    is_full_attn_bias = exists(attn_bias) and attn_bias.shape[-1] == attn_bias.shape[-2]

-        # get the diagonal
-
-        n = torch.arange(attn_bias.shape[-3], device = device)
-
-        attn_bias = einx.get_at(
-            '... [i j] w1 w2, n, n -> ... n w1 w2',
-            attn_bias, n, n
-        )
+    if is_full_attn_bias:
+        attn_bias = full_attn_bias_matrix_to_local(attn_bias, window_size = window_size)

     # carry out attention as usual

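With this change local_attn accepts the bias in either form; a sketch of the two accepted shapes (sizes made up, assuming window_size = 4 over 12 tokens, i.e. 3 windows):

    import torch

    full_bias     = torch.randn(12, 12)    # full pairwise bias, converted internally on each call
    windowed_bias = torch.randn(3, 4, 12)  # already local: (windows, window, window * 3), used as-is

The square check on the last two dimensions is what routes a full bias through full_attn_bias_matrix_to_local, so callers running many forward passes can precompute the windowed form once and skip the per-call conversion.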