Skip to content

Commit d643da7

Browse files
committed
prepare the architecture for missing atoms, by replacing the noised atom pos features with a missing atom feature, and then removing that atom from the denoising loss and other related losses
1 parent 3875b27 commit d643da7

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,6 +1741,8 @@ def __init__(
17411741

17421742
self.atom_pos_to_atom_feat = LinearNoBias(3, dim_atom)
17431743

1744+
self.missing_atom_feat = nn.Parameter(torch.zeros(dim_atom))
1745+
17441746
self.single_repr_to_atom_feat_cond = nn.Sequential(
17451747
nn.LayerNorm(dim_single),
17461748
LinearNoBias(dim_single, dim_atom)
@@ -1839,7 +1841,8 @@ def forward(
18391841
pairwise_trunk: Float['b n n dpt'],
18401842
pairwise_rel_pos_feats: Float['b n n dpr'],
18411843
molecule_atom_lens: Int['b n'],
1842-
atom_parent_ids: Int['b m'] | None = None
1844+
atom_parent_ids: Int['b m'] | None = None,
1845+
missing_atom_mask: Bool['b m']| None = None
18431846
):
18441847
w = self.atoms_per_window
18451848
device = noised_atom_pos.device
@@ -1864,7 +1867,16 @@ def forward(
18641867

18651868
# the most surprising part of the paper; no geometric biases!
18661869

1867-
atom_feats = self.atom_pos_to_atom_feat(noised_atom_pos) + atom_feats
1870+
noised_atom_pos_feats = self.atom_pos_to_atom_feat(noised_atom_pos)
1871+
1872+
# for missing atoms, replace the noise atom pos features with a missing embedding
1873+
1874+
if exists(missing_atom_mask):
1875+
noised_atom_pos_feats = einx.where('b m, d, b m d -> b m d', missing_atom_mask, self.missing_atom_feat, noised_atom_pos_feats)
1876+
1877+
# sum the noised atom position features to the atom features
1878+
1879+
atom_feats = noised_atom_pos_feats + atom_feats
18681880

18691881
# condition atom feats cond (cl) with single repr
18701882

@@ -2199,6 +2211,7 @@ def forward(
21992211
pairwise_trunk: Float['b n n dpt'],
22002212
pairwise_rel_pos_feats: Float['b n n dpr'],
22012213
molecule_atom_lens: Int['b n'],
2214+
missing_atom_mask: Bool['b m'] | None = None,
22022215
atom_parent_ids: Int['b m'] | None = None,
22032216
return_denoised_pos = False,
22042217
is_molecule_types: Bool[f'b n {IS_MOLECULE_TYPES}'] | None = None,
@@ -2227,6 +2240,7 @@ def forward(
22272240
network_condition_kwargs = dict(
22282241
atom_feats = atom_feats,
22292242
atom_mask = atom_mask,
2243+
missing_atom_mask = missing_atom_mask,
22302244
atompair_feats = atompair_feats,
22312245
atom_parent_ids = atom_parent_ids,
22322246
mask = mask,
@@ -2282,6 +2296,11 @@ def forward(
22822296

22832297
losses = losses * loss_weights
22842298

2299+
# if there are missing atoms, update the atom mask to not include them in the loss
2300+
2301+
if exists(missing_atom_mask):
2302+
atom_mask = atom_mask & ~ missing_atom_mask
2303+
22852304
# account for atom mask
22862305

22872306
mse_loss = losses[atom_mask].mean()
@@ -3337,6 +3356,7 @@ def forward(
33373356
atompair_ids: Int['b m m'] | Int['b nw {self.w} {self.w*2}'] | None = None,
33383357
is_molecule_mod: Bool['b n num_mods'] | None = None,
33393358
atom_mask: Bool['b m'] | None = None,
3359+
missing_atom_mask: Bool['b m'] | None = None,
33403360
atom_parent_ids: Int['b m'] | None = None,
33413361
token_bonds: Bool['b n n'] | None = None,
33423362
msa: Float['b s n d'] | None = None,
@@ -3656,6 +3676,7 @@ def forward(
36563676
(
36573677
atom_pos,
36583678
atom_mask,
3679+
missing_atom_mask,
36593680
atom_feats,
36603681
atom_parent_ids,
36613682
atompair_feats,
@@ -3679,6 +3700,7 @@ def forward(
36793700
for t in (
36803701
atom_pos,
36813702
atom_mask,
3703+
missing_atom_mask,
36823704
atom_feats,
36833705
atom_parent_ids,
36843706
atompair_feats,
@@ -3730,6 +3752,7 @@ def forward(
37303752
atom_feats = atom_feats,
37313753
atompair_feats = atompair_feats,
37323754
atom_parent_ids = atom_parent_ids,
3755+
missing_atom_mask = missing_atom_mask,
37333756
atom_mask = atom_mask,
37343757
mask = mask,
37353758
single_trunk_repr = single,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.127"
3+
version = "0.1.128"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,11 +417,13 @@ def test_distogram_head():
417417

418418
@pytest.mark.parametrize('window_atompair_inputs', (True, False))
419419
@pytest.mark.parametrize('stochastic_frame_average', (True, False))
420+
@pytest.mark.parametrize('missing_atoms', (True, False))
420421
@pytest.mark.parametrize('atom_transformer_intramolecular_attn', (True, False))
421422
@pytest.mark.parametrize('num_molecule_mods', (0, 5))
422423
def test_alphafold3(
423424
window_atompair_inputs: bool,
424425
stochastic_frame_average: bool,
426+
missing_atoms: bool,
425427
atom_transformer_intramolecular_attn: bool,
426428
num_molecule_mods: int
427429
):
@@ -449,6 +451,10 @@ def test_alphafold3(
449451
if num_molecule_mods > 0:
450452
is_molecule_mod = torch.zeros(2, seq_len, num_molecule_mods).uniform_(0, 1) < 0.1
451453

454+
missing_atom_mask = None
455+
if missing_atoms:
456+
missing_atom_mask = torch.randint(0, 2, (2, atom_seq_len)).bool()
457+
452458
atom_parent_ids = None
453459

454460
if atom_transformer_intramolecular_attn:
@@ -501,6 +507,7 @@ def test_alphafold3(
501507
molecule_atom_lens = molecule_atom_lens,
502508
atom_parent_ids = atom_parent_ids,
503509
atompair_inputs = atompair_inputs,
510+
missing_atom_mask = missing_atom_mask,
504511
is_molecule_types = is_molecule_types,
505512
is_molecule_mod = is_molecule_mod,
506513
additional_molecule_feats = additional_molecule_feats,

0 commit comments

Comments
 (0)