Skip to content

Commit 1ce98e2

Browse files
committed
add atom pooling with fixed window size for eventual atom14 support
1 parent 30bd14c commit 1ce98e2

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,30 @@ def mean_pool_with_lens(
292292
avg = einx.where('b n, b n d, -> b n d', mask, avg, 0.)
293293
return avg
294294

295+
@typecheck
296+
def mean_pool_fixed_windows_with_mask(
297+
feats: Float['b m d'],
298+
mask: Bool['b m'],
299+
window_size: int,
300+
return_pooled_mask: bool = False
301+
) -> Float['b n d'] | Tuple[Float['b n'], Bool['b n']]:
302+
303+
seq_len = feats.shape[-2]
304+
assert divisible_by(seq_len, window_size)
305+
306+
feats = einx.where('b m, b m d, -> b m d', mask, feats, 0.)
307+
308+
num = reduce(feats, 'b (n w) d -> b n d', 'sum', w = window_size)
309+
den = reduce(mask.float(), 'b (n w) -> b n 1', 'sum', w = window_size)
310+
311+
avg = num / den.clamp(min = 1.)
312+
313+
if not return_pooled_mask:
314+
return avg
315+
316+
pooled_mask = reduce(mask, 'b (n w) -> b n', 'any', w = window_size)
317+
return avg, pooled_mask
318+
295319
@typecheck
296320
def batch_repeat_interleave(
297321
feats: Float['b n ...'] | Bool['b n ...'] | Bool['b n'] | Int['b n'],

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.121"
3+
version = "0.2.122"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
from alphafold3_pytorch.alphafold3 import (
4545
mean_pool_with_lens,
46+
mean_pool_fixed_windows_with_mask,
4647
batch_repeat_interleave,
4748
full_pairwise_repr_to_windowed,
4849
get_cid_molecule_type,
@@ -75,6 +76,14 @@ def test_mean_pool_with_lens():
7576

7677
assert torch.allclose(pooled, torch.tensor([[[1.], [2.], [1.]]]))
7778

79+
def test_mean_pool_with_mask():
80+
seq = torch.tensor([[[1.], [100.], [1.], [2.], [2.], [100.], [1.], [1.], [100.]]])
81+
mask = torch.tensor([[True, False, True, True, True, False, True, True, False]])
82+
83+
pooled = mean_pool_fixed_windows_with_mask(seq, mask, window_size = 3)
84+
85+
assert torch.allclose(pooled, torch.tensor([[[1.], [2.], [1.]]]))
86+
7887
def test_batch_repeat_interleave():
7988
seq = torch.tensor([[[1.], [2.], [4.]], [[1.], [2.], [4.]]])
8089
lens = torch.tensor([[3, 4, 2], [2, 5, 1]]).long()

0 commit comments

Comments
 (0)