Skip to content

Commit dece70b

Browse files
committed
bring in a custom pad fn for clarity
1 parent dbac2c6 commit dece70b

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

alphafold3_pytorch/attention.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
from einops import einsum, repeat, rearrange, pack, unpack
1111
from einops.layers.torch import Rearrange
1212

13-
from alphafold3_pytorch.typing import Float, Int, Bool, typecheck
13+
from alphafold3_pytorch.typing import (
14+
Float,
15+
Int,
16+
Bool,
17+
typecheck
18+
)
1419

1520
# constants
1621

@@ -36,6 +41,18 @@ def pack_one(t, pattern):
3641
def unpack_one(t, ps, pattern):
    """Inverse of pack_one: recover the single tensor packed under `pattern`."""
    tensors = unpack(t, ps, pattern)
    return tensors[0]
3843

44+
@typecheck
def pad_at_dim(
    t,
    pad: Tuple[int, int],
    *,
    dim = -1,
    value = 0.
):
    """Pad tensor `t` by `pad = (left, right)` along dimension `dim`.

    F.pad takes pad widths listed from the LAST dimension inward, so a
    (0, 0) pair is prepended for every dimension sitting to the right of
    `dim`, leaving those dimensions untouched.
    """
    if dim < 0:
        dims_from_right = -dim - 1
    else:
        dims_from_right = t.ndim - dim - 1

    # no-op padding for all trailing dimensions after `dim`
    no_pad = (0, 0) * dims_from_right
    return F.pad(t, (*no_pad, *pad), value = value)
55+
3956
# multi-head attention
4057

4158
class Attention(Module):
@@ -219,7 +236,7 @@ def local_attn(
219236
padding_needed = (window_size - (seq_len % window_size)) % window_size
220237

221238
if padding_needed > 0:
222-
q, k, v = tuple(F.pad(t, (0, 0, 0, padding_needed), value = 0.) for t in (q, k, v))
239+
q, k, v = tuple(pad_at_dim(t, (0, padding_needed), value = 0., dim = -2) for t in (q, k, v))
223240
mask = F.pad(mask, (0, padding_needed), value = False)
224241

225242
# break into windows
@@ -230,7 +247,7 @@ def local_attn(
230247
# just do radius of 1 for now
231248
# perhaps not even necessary, and could try shifted windows (a la Swin)
232249

233-
k, v = tuple(F.pad(t, (0, 0, 1, 1)) for t in (k, v))
250+
k, v = tuple(pad_at_dim(t, (1, 1), dim = -2) for t in (k, v))
234251
mask = F.pad(mask, (1, 1), value = False)
235252

236253
k, v = tuple(torch.cat((t[..., :-2, :], t[..., 1:-1, :], t[..., 2:, :]), dim = -2) for t in (k, v))
@@ -241,7 +258,7 @@ def local_attn(
241258
if exists(attn_bias):
242259
attn_bias = F.pad(attn_bias, (0, padding_needed, 0, padding_needed), value = 0.)
243260
attn_bias = rearrange(attn_bias, '... (i w1) (j w2) -> ... i j w1 w2', w1 = window_size, w2 = window_size)
244-
attn_bias = F.pad(attn_bias, (0, 0, 0, 0, 1, 1), value = 0.)
261+
attn_bias = pad_at_dim(attn_bias, (1, 1), dim = -3, value = 0.)
245262

246263
attn_bias = torch.cat((
247264
attn_bias[..., :-2, :, :],

0 commit comments

Comments
 (0)