Skip to content

Commit 43e3db6

Browse files
committed
handle any type of molecule modification (phosphorylation, glycosylation, methylation) in one fell swoop, as a sparse boolean tensor
1 parent 3b0584a commit 43e3db6

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2967,6 +2967,7 @@ def __init__(
29672967
num_molecule_types: int = 32, # restype in additional residue information, apparently 32 (must be human amino acids + nucleotides + something else)
29682968
num_atom_embeds: int | None = None,
29692969
num_atompair_embeds: int | None = None,
2970+
num_molecule_mods: int | None = None,
29702971
distance_bins: List[float] = torch.linspace(3, 20, 38).float().tolist(),
29712972
ignore_index = -1,
29722973
num_dist_bins: int | None = None,
@@ -3064,6 +3065,16 @@ def __init__(
30643065
self.has_atom_embeds = has_atom_embeds
30653066
self.has_atompair_embeds = has_atompair_embeds
30663067

3068+
# residue or nucleotide modifications
3069+
3070+
has_molecule_mod_embeds = num_molecule_mods > 0
3071+
self.num_molecule_mods = num_molecule_mods
3072+
3073+
if has_molecule_mod_embeds:
3074+
self.molecule_mod_embeds = nn.Embedding(num_molecule_mods, dim_single)
3075+
3076+
self.has_molecule_mod_embeds = has_molecule_mod_embeds
3077+
30673078
# atoms per window
30683079

30693080
self.atoms_per_window = atoms_per_window
@@ -3302,6 +3313,7 @@ def forward(
33023313
additional_token_feats: Float['b n {self.dim_additional_token_feats}'] | None = None,
33033314
atom_ids: Int['b m'] | None = None,
33043315
atompair_ids: Int['b m m'] | Int['b nw w1 w2'] | None = None,
3316+
is_molecule_mod: Bool['b n {self.num_molecule_mods}'] | None = None,
33053317
atom_mask: Bool['b m'] | None = None,
33063318
atom_parent_ids: Int['b m'] | None = None,
33073319
token_bonds: Bool['b n n'] | None = None,
@@ -3401,6 +3413,26 @@ def forward(
34013413

34023414
atompair_feats = atompair_feats + atompair_embeds
34033415

3416+
# handle molecule modifications, if any
3417+
3418+
assert not (exists(is_molecule_mod) ^ self.has_molecule_mod_embeds), 'you either set `num_molecule_mods` and did not pass in `is_molecule_mod` or vice versa'
3419+
3420+
if self.has_molecule_mod_embeds:
3421+
single_init, seq_packed_shape = pack_one(single_init, '* ds')
3422+
3423+
is_molecule_mod, _ = pack_one(is_molecule_mod, '* mods')
3424+
3425+
if not is_molecule_mod.is_sparse:
3426+
is_molecule_mod = is_molecule_mod.to_sparse()
3427+
3428+
seq_indices, mod_id = is_molecule_mod.indices()
3429+
scatter_values = self.molecule_mod_embeds(mod_id)
3430+
3431+
seq_indices = repeat(seq_indices, 'n -> n ds', ds = single_init.shape[-1])
3432+
single_init = single_init.scatter_add(0, seq_indices, scatter_values)
3433+
3434+
single_init = unpack_one(single_init, seq_packed_shape, '* ds')
3435+
34043436
# relative positional encoding
34053437

34063438
relative_position_encoding = self.relative_position_encoding(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.108"
3+
version = "0.1.109"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,12 @@ def test_distogram_head():
418418
@pytest.mark.parametrize('window_atompair_inputs', (True, False))
419419
@pytest.mark.parametrize('stochastic_frame_average', (True, False))
420420
@pytest.mark.parametrize('atom_transformer_intramolecular_attn', (True, False))
421+
@pytest.mark.parametrize('num_molecule_mods', (0, 5))
421422
def test_alphafold3(
422423
window_atompair_inputs: bool,
423424
stochastic_frame_average: bool,
424-
atom_transformer_intramolecular_attn: bool
425+
atom_transformer_intramolecular_attn: bool,
426+
num_molecule_mods: int
425427
):
426428
seq_len = 16
427429
atoms_per_window = 27
@@ -443,6 +445,10 @@ def test_alphafold3(
443445
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
444446
molecule_ids = torch.randint(0, 32, (2, seq_len))
445447

448+
is_molecule_mod = None
449+
if num_molecule_mods > 0:
450+
is_molecule_mod = torch.zeros(2, seq_len, num_molecule_mods).uniform_(0, 1) < 0.1
451+
446452
atom_parent_ids = None
447453

448454
if atom_transformer_intramolecular_attn:
@@ -495,6 +501,7 @@ def test_alphafold3(
495501
atom_parent_ids = atom_parent_ids,
496502
atompair_inputs = atompair_inputs,
497503
is_molecule_types = is_molecule_types,
504+
is_molecule_mod = is_molecule_mod,
498505
additional_molecule_feats = additional_molecule_feats,
499506
additional_token_feats = additional_token_feats,
500507
token_bonds = token_bonds,
@@ -519,6 +526,7 @@ def test_alphafold3(
519526
atom_inputs = atom_inputs,
520527
molecule_ids = molecule_ids,
521528
molecule_atom_lens = molecule_atom_lens,
529+
is_molecule_mod = is_molecule_mod,
522530
atompair_inputs = atompair_inputs,
523531
is_molecule_types = is_molecule_types,
524532
additional_molecule_feats = additional_molecule_feats,

0 commit comments

Comments
 (0)