
Commit b4b1ed5

take a tiny but necessary step for efficient atomic attention biasing

1 parent 244f097 commit b4b1ed5

2 files changed: +42 -21 lines

alphafold3_pytorch/alphafold3.py

Lines changed: 5 additions & 1 deletion

@@ -63,7 +63,11 @@
     typecheck
 )

-from alphafold3_pytorch.attention import Attention
+from alphafold3_pytorch.attention import (
+    Attention,
+    full_attn_bias_matrix_to_local
+)
+
 from taylor_series_linear_attention import TaylorSeriesLinearAttn

 import einx

alphafold3_pytorch/attention.py

Lines changed: 37 additions & 20 deletions

@@ -53,6 +53,38 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)

+# for changing full attention bias matrix to a local windowed one for atom attention
+
+@typecheck
+def full_attn_bias_matrix_to_local(
+    attn_bias: Float['... m m'],
+    window_size: int
+) -> Float['... n w (w*3)']:
+
+    seq_len, device = attn_bias.shape[-1], attn_bias.device
+
+    padding_needed = (window_size - (seq_len % window_size)) % window_size
+    attn_bias = F.pad(attn_bias, (0, padding_needed, 0, padding_needed), value = 0.)
+    attn_bias = rearrange(attn_bias, '... (i w1) (j w2) -> ... i j w1 w2', w1 = window_size, w2 = window_size)
+    attn_bias = pad_at_dim(attn_bias, (1, 1), dim = -3, value = 0.)
+
+    attn_bias = torch.cat((
+        attn_bias[..., :-2, :, :],
+        attn_bias[..., 1:-1, :, :],
+        attn_bias[..., 2:, :, :]
+    ), dim = -1)
+
+    # get the diagonal
+
+    n = torch.arange(attn_bias.shape[-3], device = device)
+
+    attn_bias = einx.get_at(
+        '... [i j] w1 w2, n, n -> ... n w1 w2',
+        attn_bias, n, n
+    )
+
+    return attn_bias
+
 # multi-head attention

 class Attention(Module):
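
In effect, the new helper reshapes a dense pairwise bias of shape (..., m, m) into a windowed one of shape (..., n, w, w*3): w is the window size, n = ceil(m / w) is the number of query windows, and each query window keeps its bias against its own key window plus one neighbouring window on each side, zero padded at the two ends. A minimal shape-check sketch, assuming the helper is imported exactly as added in this commit:

import torch
from alphafold3_pytorch.attention import full_attn_bias_matrix_to_local

# dense pairwise bias over 10 atoms; it is zero padded internally to 12 for a window size of 4
attn_bias = torch.randn(1, 10, 10)

local_bias = full_attn_bias_matrix_to_local(attn_bias, window_size = 4)

# 3 query windows of 4 atoms each, every window biased against 3 * 4 = 12 neighbouring keys
assert local_bias.shape == (1, 3, 4, 12)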
@@ -218,7 +250,7 @@ def local_attn(
     k: Float['b h n d'],
     v: Float['b h n d'],
     mask: Bool['b n'] | None = None,
-    attn_bias: Float['... n n'] | None = None
+    attn_bias: Float['... n n'] | Float['... n w (w*3)'] | None = None
 ) -> Float['b h n d']:
     """
     simple local attention with a radius of 1 window size
@@ -233,7 +265,7 @@ def local_attn(

     # pad to multiple of window size if needed

-    padding_needed = (window_size - (seq_len % window_size)) % window_size
+    padding_needed = (window_size - (seq_len % window_size)) % window_size

     if padding_needed > 0:
         q, k, v = tuple(pad_at_dim(t, (0, padding_needed), value = 0., dim = -2) for t in (q, k, v))
@@ -255,25 +287,10 @@ def local_attn(

     # handle attention bias (inefficiently)

-    if exists(attn_bias):
-        attn_bias = F.pad(attn_bias, (0, padding_needed, 0, padding_needed), value = 0.)
-        attn_bias = rearrange(attn_bias, '... (i w1) (j w2) -> ... i j w1 w2', w1 = window_size, w2 = window_size)
-        attn_bias = pad_at_dim(attn_bias, (1, 1), dim = -3, value = 0.)
-
-        attn_bias = torch.cat((
-            attn_bias[..., :-2, :, :],
-            attn_bias[..., 1:-1, :, :],
-            attn_bias[..., 2:, :, :]
-        ), dim = -1)
+    is_full_attn_bias = attn_bias.shape[-1] == attn_bias.shape[-2]

-        # get the diagonal
-
-        n = torch.arange(attn_bias.shape[-3], device = device)
-
-        attn_bias = einx.get_at(
-            '... [i j] w1 w2, n, n -> ... n w1 w2',
-            attn_bias, n, n
-        )
+    if exists(attn_bias) and is_full_attn_bias:
+        attn_bias = full_attn_bias_matrix_to_local(attn_bias, window_size = window_size)

     # carry out attention as usual

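After this hunk, local_attn accepts the bias in either form: a square '... n n' matrix is still converted on the fly, while a tensor whose last two dimensions differ is assumed to be already windowed and is used directly. That check is what will let a caller run the conversion once and reuse the result across calls, which appears to be the "tiny but necessary step" the commit message refers to. A small sketch of the distinction the new check relies on, again assuming the package at this commit:

import torch
from alphafold3_pytorch.attention import full_attn_bias_matrix_to_local

window_size = 4

full_bias = torch.randn(1, 16, 16)
windowed_bias = full_attn_bias_matrix_to_local(full_bias, window_size = window_size)

# local_attn keys its detection off the last two dimensions
assert full_bias.shape[-1] == full_bias.shape[-2]          # square -> converted internally
assert windowed_bias.shape[-1] != windowed_bias.shape[-2]  # (w, w*3) -> used directly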