Commit 291f715

committed: complete the function necessary to broadcast token features back to atom features for the packed representation
1 parent 57f6da2 commit 291f715

File tree: 2 files changed, +68 -3 lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 59 additions & 0 deletions
@@ -89,6 +89,20 @@ def inner(t, *args, **kwargs):
 
 # packed atom representation functions
 
+@typecheck
+def lens_to_mask(
+    lens: Int['b n'] | Int[' b']
+) -> Bool['b m']:
+
+    device = lens.device
+
+    if lens.ndim == 2:
+        lens = reduce(lens, 'b m -> b', 'sum')
+
+    max_len = lens.amax()
+    arange = torch.arange(max_len, device = device)
+    return einx.less('m, b -> b m', arange, lens)
+
 @typecheck
 def mean_pool_with_lens(
     feats: Float['b m d'],
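A quick illustration of the new lens_to_mask helper (an editor's sketch, not part of the commit): it turns per-entry lengths into a boolean mask over the packed atom axis. The plain-PyTorch comparison below does the same thing as the einx.less call.

import torch

lens = torch.tensor([2, 4, 3])            # total atoms per batch entry
arange = torch.arange(lens.amax())        # tensor([0, 1, 2, 3])
mask = arange[None, :] < lens[:, None]    # same comparison as einx.less('m, b -> b m', arange, lens)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True, False]])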
@@ -115,6 +129,51 @@ def mean_pool_with_lens(
     avg = einx.where('b n, b n d, -> b n d', mask, avg, 0.)
     return avg
 
+@typecheck
+def repeat_consecutive_with_lens(
+    feats: Float['b n d'],
+    lens: Int['b n'],
+    max_length: int | None = None,
+    return_mask = False
+) -> Float['b m d'] | Tuple[Float['b m d'], Bool['b m']]:
+
+    device = feats.device
+
+    # derive arange from the max length
+
+    total_lens = reduce(lens, 'b n -> b', 'sum')
+
+    if not exists(max_length):
+        max_length = total_lens.amax()
+
+    arange = torch.arange(max_length, device = device)
+
+    # get packed atom mask from the total lengths
+
+    mask = lens_to_mask(total_lens)
+
+    lens = F.pad(lens, (1, 0), value = 0)
+    cumsum_lens = lens.cumsum(dim = -1)
+    left_index, right_index = cumsum_lens[:, :-1], cumsum_lens[:, 1:]
+
+    # derive the mask for consecutives per feat
+
+    left_mask = einx.greater_equal('m, b n -> b n m', arange, left_index)
+    right_mask = einx.less('m, b n -> b n m', arange, right_index)
+
+    consecutive_mask = left_mask & right_mask
+
+    # now broadcast and sum for consecutive features
+
+    feats = einx.multiply('b n d, b n m -> b n m d', feats, consecutive_mask.float())
+    feats = reduce(feats, 'b n m d -> b m d', 'sum')
+
+    if not return_mask:
+        return feats
+
+    mask = mask[:, :max_length]
+    return feats, mask
+
 # linear and outer sum
 # for single repr -> pairwise pattern throughout this architecture
 
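For reference (a sketch, not part of the commit): for a single batch entry, the masked broadcast-and-sum above reduces to torch.repeat_interleave along the token axis, which is exactly what the new test below checks.

import torch

feats = torch.tensor([[[1.], [2.], [4.]]])    # token features, shape (b = 1, n = 3, d = 1)
lens  = torch.tensor([[3, 4, 2]])             # atoms per token

out = torch.repeat_interleave(feats[0], lens[0], dim = 0)[None]
# tensor([[[1.], [1.], [1.], [2.], [2.], [2.], [2.], [4.], [4.]]]), shape (1, 9, 1)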

tests/test_af3.py

Lines changed: 9 additions & 3 deletions
@@ -2,8 +2,6 @@
 os.environ['TYPECHECK'] = 'True'
 
 import torch
-from torch.nn.utils.rnn import pad_sequence
-
 import pytest
 
 from alphafold3_pytorch import (
@@ -27,7 +25,8 @@
 )
 
 from alphafold3_pytorch.alphafold3 import (
-    mean_pool_with_lens
+    mean_pool_with_lens,
+    repeat_consecutive_with_lens
 )
 
 def test_mean_pool_with_lens():
def test_mean_pool_with_lens():
@@ -37,6 +36,13 @@ def test_mean_pool_with_lens():
 
     assert torch.allclose(pooled, torch.tensor([[[1.], [2.], [1.]]]))
 
+def test_repeat_consecutive_with_lens():
+    seq = torch.tensor([[[1.], [2.], [4.]]])
+    lens = torch.tensor([[3, 4, 2]]).long()
+    repeated = repeat_consecutive_with_lens(seq, lens)
+
+    assert torch.allclose(repeated, torch.tensor([[[1.], [1.], [1.], [2.], [2.], [2.], [2.], [4.], [4.]]]))
+
 def test_smooth_lddt_loss():
     pred_coords = torch.randn(2, 100, 3)
     true_coords = torch.randn(2, 100, 3)
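A hedged usage note (not part of the commit): since every atom of a token receives an identical copy of that token's features, mean-pooling the broadcast result should recover the original token features. The sketch below assumes mean_pool_with_lens(feats, lens) takes packed atom features plus per-token atom counts, as the existing test suggests.

import torch
from alphafold3_pytorch.alphafold3 import mean_pool_with_lens, repeat_consecutive_with_lens

tokens = torch.randn(1, 3, 8)
lens = torch.tensor([[3, 4, 2]]).long()

atoms = repeat_consecutive_with_lens(tokens, lens)   # (1, 9, 8) packed atom features
pooled = mean_pool_with_lens(atoms, lens)            # back to (1, 3, 8)
assert torch.allclose(pooled, tokens, atol = 1e-6)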
