
Commit 5878227

cleanup

1 parent 865d34c

alphafold3_pytorch/attention.py

Lines changed: 10 additions & 30 deletions
@@ -55,36 +55,6 @@ def pad_at_dim(
 
 # for changing full attention bias matrix to a local windowed one for atom attention
 
-@typecheck
-def full_attn_bias_to_windowed(
-    attn_bias: Float['... m m'],
-    window_size: int
-) -> Float['... n w (w*3)']:
-
-    seq_len, device = attn_bias.shape[-1], attn_bias.device
-
-    padding_needed = (window_size - (seq_len % window_size)) % window_size
-    attn_bias = F.pad(attn_bias, (0, padding_needed, 0, padding_needed), value = 0.)
-    attn_bias = rearrange(attn_bias, '... (i w1) (j w2) -> ... i j w1 w2', w1 = window_size, w2 = window_size)
-    attn_bias = pad_at_dim(attn_bias, (1, 1), dim = -3, value = 0.)
-
-    attn_bias = torch.cat((
-        attn_bias[..., :-2, :, :],
-        attn_bias[..., 1:-1, :, :],
-        attn_bias[..., 2:, :, :]
-    ), dim = -1)
-
-    # get the diagonal
-
-    n = torch.arange(attn_bias.shape[-3], device = device)
-
-    attn_bias = einx.get_at(
-        '... [i j] w1 w2, n, n -> ... n w1 w2',
-        attn_bias, n, n
-    )
-
-    return attn_bias
-
 @typecheck
 def full_pairwise_repr_to_windowed(
     pairwise_repr: Float['... m m dp'],
@@ -115,6 +85,16 @@ def full_pairwise_repr_to_windowed(
 
     return pairwise_repr
 
+@typecheck
+def full_attn_bias_to_windowed(
+    attn_bias: Float['... m m'],
+    window_size: int
+) -> Float['... n w (w*3)']:
+
+    attn_bias = rearrange(attn_bias, '... -> ... 1')
+    attn_bias = full_pairwise_repr_to_windowed(attn_bias, window_size = window_size)
+    return rearrange(attn_bias, '... 1 -> ...')
+
 # multi-head attention
 
 class Attention(Module):
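For context on the cleanup: the reinstated full_attn_bias_to_windowed now treats the scalar attention bias as a pairwise representation with a singleton feature dimension (added and stripped via einops-style rearranges '... -> ... 1' and '... 1 -> ...'), delegating all padding and windowing to full_pairwise_repr_to_windowed and removing the duplicated logic. A minimal usage sketch of the resulting shape contract (not part of the commit; the batch size, sequence length, and window size below are illustrative assumptions):

    import torch
    from alphafold3_pytorch.attention import full_attn_bias_to_windowed

    m, w = 16, 4                            # sequence length and window size (assumed values)
    attn_bias = torch.randn(2, m, m)        # '... m m' full attention bias

    windowed = full_attn_bias_to_windowed(attn_bias, window_size = w)

    # per the annotated return shape '... n w (w*3)': n = m // w windows of
    # size w, each window seeing itself plus one window of keys on either side
    assert windowed.shape == (2, m // w, w, w * 3)

Because the bias path and the pairwise-repr path now share a single windowing implementation, any fix to the padding or neighbor-concatenation logic only needs to be made in one place.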
