Skip to content

Commit 9fcd59f

Browse files
committed
complete the sample without replacement logic for msa
1 parent 3e9faea commit 9fcd59f

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,10 +696,13 @@ def __init__(
696696
msa_pwa_dropout_row_prob = 0.15,
697697
msa_pwa_heads = 8,
698698
msa_pwa_dim_head = 32,
699-
pairwise_block_kwargs: dict = dict()
699+
pairwise_block_kwargs: dict = dict(),
700+
max_num_msa: int | None = None
700701
):
701702
super().__init__()
702703

704+
self.max_num_msa = default(max_num_msa, float('inf')) # cap the number of MSAs, will do sample without replacement if exceeds
705+
703706
self.msa_init_proj = LinearNoBias(dim_msa_input, dim_msa) if exists(dim_msa_input) else nn.Identity()
704707

705708
self.single_to_msa_feats = LinearNoBias(dim_single, dim_msa)
@@ -752,6 +755,24 @@ def forward(
752755
msa_mask: Bool['b s'] | None = None,
753756
) -> Float['b n n dp']:
754757

758+
batch, num_msa, device = *msa.shape[:2], msa.device
759+
760+
# sample without replacement
761+
762+
if num_msa > self.max_num_msa:
763+
rand = torch.randn((batch, num_msa), device = device)
764+
765+
if exists(msa_mask):
766+
rand.masked_fill_(~msa_mask, max_neg_value(msa))
767+
768+
indices = rand.topk(self.max_num_msa, dim = -1).indices
769+
770+
msa = einx.get_at('b [s] n dm, b sampled -> b sampled n dm', msa, indices)
771+
772+
if exists(msa_mask):
773+
msa_mask = einx.get_at('b [s], b sampled -> b sampled', msa_mask, indices)
774+
775+
# process msa
755776

756777
msa = self.msa_init_proj(msa)
757778

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.9"
3+
version = "0.0.10"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,9 @@ def test_msa_module():
132132
msa = torch.randn(2, 7, 16, 64)
133133
mask = torch.randint(0, 2, (2, 16)).bool()
134134

135-
msa_module = MSAModule()
135+
msa_module = MSAModule(
136+
max_num_msa = 3 # will randomly select 3 out of the MSAs, accounting for mask, using sample without replacement
137+
)
136138

137139
pairwise_out = msa_module(
138140
msa = msa,

0 commit comments

Comments
 (0)