1010from einops import einsum , repeat , rearrange , pack , unpack
1111from einops .layers .torch import Rearrange
1212
13- from alphafold3_pytorch .typing import Float , Int , Bool , typecheck
13+ from alphafold3_pytorch .typing import (
14+ Float ,
15+ Int ,
16+ Bool ,
17+ typecheck
18+ )
1419
# constants
1621
@@ -36,6 +41,18 @@ def pack_one(t, pattern):
def unpack_one(t, ps, pattern):
    """Inverse of `pack_one`: unpack `t` using packed shapes `ps` and return the sole tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
3843
@typecheck
def pad_at_dim(
    t,
    pad: Tuple[int, int],
    *,
    dim = -1,
    value = 0.
):
    """Pad tensor `t` along a single dimension.

    Args:
        t: tensor to pad
        pad: (left, right) pad amounts for dimension `dim`
        dim: dimension to pad (negative indices count from the end), default last
        value: fill value for the padding, default 0.

    Returns:
        padded tensor, same rank as `t`
    """
    # F.pad's spec is ordered from the last dimension inward, two ints per dim;
    # count how many trailing dims sit to the right of `dim` and leave them unpadded
    if dim < 0:
        dims_from_right = -dim - 1
    else:
        dims_from_right = t.ndim - dim - 1

    pad_spec = (0, 0) * dims_from_right + tuple(pad)
    return F.pad(t, pad_spec, value = value)
# multi-head attention
4057
4158class Attention (Module ):
@@ -219,7 +236,7 @@ def local_attn(
219236 padding_needed = (window_size - (seq_len % window_size )) % window_size
220237
221238 if padding_needed > 0 :
222- q , k , v = tuple (F . pad (t , (0 , 0 , 0 , padding_needed ), value = 0. ) for t in (q , k , v ))
239+ q , k , v = tuple (pad_at_dim (t , (0 , padding_needed ), value = 0. , dim = - 2 ) for t in (q , k , v ))
223240 mask = F .pad (mask , (0 , padding_needed ), value = False )
224241
225242 # break into windows
@@ -230,7 +247,7 @@ def local_attn(
230247 # just do radius of 1 for now
231248 # perhaps not even necessary, and could try shifted windows (a la Swin)
232249
233- k , v = tuple (F . pad (t , (0 , 0 , 1 , 1 )) for t in (k , v ))
250+ k , v = tuple (pad_at_dim (t , (1 , 1 ), dim = - 2 ) for t in (k , v ))
234251 mask = F .pad (mask , (1 , 1 ), value = False )
235252
236253 k , v = tuple (torch .cat ((t [..., :- 2 , :], t [..., 1 :- 1 , :], t [..., 2 :, :]), dim = - 2 ) for t in (k , v ))
@@ -241,7 +258,7 @@ def local_attn(
241258 if exists (attn_bias ):
242259 attn_bias = F .pad (attn_bias , (0 , padding_needed , 0 , padding_needed ), value = 0. )
243260 attn_bias = rearrange (attn_bias , '... (i w1) (j w2) -> ... i j w1 w2' , w1 = window_size , w2 = window_size )
244- attn_bias = F . pad (attn_bias , (0 , 0 , 0 , 0 , 1 , 1 ) , value = 0. )
261+ attn_bias = pad_at_dim (attn_bias , (1 , 1 ), dim = - 3 , value = 0. )
245262
246263 attn_bias = torch .cat ((
247264 attn_bias [..., :- 2 , :, :],
0 commit comments