@@ -55,36 +55,6 @@ def pad_at_dim(
 
 # for changing full attention bias matrix to a local windowed one for atom attention
 
-@typecheck
-def full_attn_bias_to_windowed(
-    attn_bias: Float['... m m'],
-    window_size: int
-) -> Float['... n w (w*3)']:
-
-    seq_len, device = attn_bias.shape[-1], attn_bias.device
-
-    padding_needed = (window_size - (seq_len % window_size)) % window_size
-    attn_bias = F.pad(attn_bias, (0, padding_needed, 0, padding_needed), value = 0.)
-    attn_bias = rearrange(attn_bias, '... (i w1) (j w2) -> ... i j w1 w2', w1 = window_size, w2 = window_size)
-    attn_bias = pad_at_dim(attn_bias, (1, 1), dim = -3, value = 0.)
-
-    attn_bias = torch.cat((
-        attn_bias[..., :-2, :, :],
-        attn_bias[..., 1:-1, :, :],
-        attn_bias[..., 2:, :, :]
-    ), dim = -1)
-
-    # get the diagonal
-
-    n = torch.arange(attn_bias.shape[-3], device = device)
-
-    attn_bias = einx.get_at(
-        '... [i j] w1 w2, n, n -> ... n w1 w2',
-        attn_bias, n, n
-    )
-
-    return attn_bias
-
 @typecheck
 def full_pairwise_repr_to_windowed(
     pairwise_repr: Float['... m m dp'],
@@ -115,6 +85,16 @@ def full_pairwise_repr_to_windowed(
 
     return pairwise_repr
 
+@typecheck
+def full_attn_bias_to_windowed(
+    attn_bias: Float['... m m'],
+    window_size: int
+) -> Float['... n w (w*3)']:
+
+    attn_bias = rearrange(attn_bias, '... -> ... 1')
+    attn_bias = full_pairwise_repr_to_windowed(attn_bias, window_size = window_size)
+    return rearrange(attn_bias, '... 1 -> ...')
+
 # multi-head attention
 
 class Attention(Module):
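
As a sanity check on the refactor, here is a minimal usage sketch that is not part of the commit. It assumes the new full_attn_bias_to_windowed and full_pairwise_repr_to_windowed above are in scope, and that the pairwise helper pads the sequence length up to a multiple of window_size exactly as the removed implementation did; the tensor shapes below are illustrative only.

import torch

# a batch of full attention bias matrices, m = 27 (shapes here are made up)
attn_bias = torch.randn(2, 27, 27)
window_size = 5

windowed = full_attn_bias_to_windowed(attn_bias, window_size = window_size)

# m = 27 is padded up to 30, giving n = 6 windows of width w = 5; each window
# sees itself plus one neighbouring window on each side, hence last dim w * 3 = 15
assert windowed.shape == (2, 6, 5, 15)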