Skip to content

Commit 7a5a983

Browse files
committed
deviating from paper, allow for atom and bond embeddings
1 parent d66c56c commit 7a5a983

File tree

3 files changed

+101
-1
lines changed

3 files changed

+101
-1
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2847,6 +2847,8 @@ def __init__(
28472847
dim_single = 384,
28482848
dim_pairwise = 128,
28492849
dim_token = 768,
2850+
num_atom_embeds: int | None = None,
2851+
num_atompair_embeds: int | None = None,
28502852
distance_bins: List[float] = torch.linspace(3, 20, 38).float().tolist(),
28512853
ignore_index = -1,
28522854
num_dist_bins: int | None = None,
@@ -2925,6 +2927,20 @@ def __init__(
29252927
):
29262928
super().__init__()
29272929

2930+
# optional atom and atom bond embeddings
2931+
2932+
has_atom_embeds = exists(num_atom_embeds)
2933+
has_atompair_embeds = exists(num_atompair_embeds)
2934+
2935+
if has_atom_embeds:
2936+
self.atom_embeds = nn.Embedding(num_atom_embeds, dim_atom)
2937+
2938+
if has_atompair_embeds:
2939+
self.atompair_embeds = nn.Embedding(num_atompair_embeds, dim_atompair)
2940+
2941+
self.has_atom_embeds = has_atom_embeds
2942+
self.has_atompair_embeds = has_atompair_embeds
2943+
29282944
# atoms per window
29292945

29302946
self.atoms_per_window = atoms_per_window
@@ -3143,6 +3159,8 @@ def forward(
31433159
atompair_inputs: Float['b m m dapi'] | Float['b nw w1 w2 dapi'],
31443160
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}'],
31453161
molecule_atom_lens: Int['b n'],
3162+
atom_ids: Int['b m'] | None = None,
3163+
atompair_ids: Int['b m m'] | Int['b nw w1 w2'] | None = None,
31463164
atom_mask: Bool['b m'] | None = None,
31473165
token_bond: Bool['b n n'] | None = None,
31483166
msa: Float['b s n d'] | None = None,
@@ -3217,6 +3235,23 @@ def forward(
32173235
molecule_atom_lens = molecule_atom_lens
32183236
)
32193237

3238+
# handle maybe atom and atompair embeddings
3239+
3240+
assert not (exists(atom_ids) ^ self.has_atom_embeds), 'you either set `num_atom_embeds` and did not pass in `atom_ids` or vice versa'
3241+
assert not (exists(atompair_ids) ^ self.has_atompair_embeds), 'you either set `num_atompair_embeds` and did not pass in `atompair_ids` or vice versa'
3242+
3243+
if self.has_atom_embeds:
3244+
atom_embeds = self.atom_embeds(atom_ids)
3245+
atom_feats = atom_feats + atom_embeds
3246+
3247+
if self.has_atompair_embeds:
3248+
atompair_embeds = self.atompair_embeds(atompair_ids)
3249+
3250+
if atompair_embeds.ndim == 4:
3251+
atompair_embeds = full_pairwise_repr_to_windowed(atompair_embeds, window_size = self.atoms_per_window)
3252+
3253+
atompair_feats = atompair_feats + atompair_embeds
3254+
32203255
# relative positional encoding
32213256

32223257
relative_position_encoding = self.relative_position_encoding(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.42"
3+
version = "0.1.43"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,71 @@ def test_alphafold3_force_return_loss():
635635

636636
assert loss == 0.
637637

638+
def test_alphafold3_with_atom_and_bond_embeddings():
    """Exercise the optional atom / atompair embedding path end to end.

    Constructs an Alphafold3 with `num_atom_embeds` / `num_atompair_embeds`
    set, passes the matching `atom_ids` / `atompair_ids`, and runs one
    training forward pass, checking that a scalar loss is returned.
    """
    alphafold3 = Alphafold3(
        num_atom_embeds = 7,
        num_atompair_embeds = 3,
        dim_atom_inputs = 77,
        dim_template_feats = 44
    )

    # mock inputs

    seq_len = 16

    molecule_atom_lens = torch.randint(1, 3, (2, seq_len))
    # `.amax()` yields a 0-dim tensor; convert to a plain int so it can be
    # used unambiguously as a shape argument below
    atom_seq_len = int(molecule_atom_lens.sum(dim = -1).amax())

    # ids must lie within the embedding ranges declared above (7 and 3)
    atom_ids = torch.randint(0, 7, (2, atom_seq_len))
    atompair_ids = torch.randint(0, 3, (2, atom_seq_len, atom_seq_len))

    atom_inputs = torch.randn(2, atom_seq_len, 77)
    atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)

    additional_molecule_feats = torch.randn(2, seq_len, 10)

    template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
    template_mask = torch.ones((2, 2)).bool()

    msa = torch.randn(2, 7, seq_len, 64)
    msa_mask = torch.ones((2, 7)).bool()

    # required for training, but omitted on inference

    atom_pos = torch.randn(2, atom_seq_len, 3)
    molecule_atom_indices = molecule_atom_lens - 1 # last atom, as an example

    distance_labels = torch.randint(0, 37, (2, seq_len, seq_len))
    pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
    pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
    plddt_labels = torch.randint(0, 50, (2, seq_len))
    resolved_labels = torch.randint(0, 2, (2, seq_len))

    # train

    loss = alphafold3(
        num_recycling_steps = 2,
        atom_ids = atom_ids,
        atompair_ids = atompair_ids,
        atom_inputs = atom_inputs,
        atompair_inputs = atompair_inputs,
        molecule_atom_lens = molecule_atom_lens,
        additional_molecule_feats = additional_molecule_feats,
        msa = msa,
        msa_mask = msa_mask,
        templates = template_feats,
        template_mask = template_mask,
        atom_pos = atom_pos,
        molecule_atom_indices = molecule_atom_indices,
        distance_labels = distance_labels,
        pae_labels = pae_labels,
        pde_labels = pde_labels,
        plddt_labels = plddt_labels,
        resolved_labels = resolved_labels
    )

    # a training forward pass returns a scalar loss tensor, which has exactly
    # one element; the original `loss.numel() == 0` asserts an *empty* tensor
    # and could never hold
    assert loss.numel() == 1
702+
638703
# test creation from config
639704

640705
def test_alphafold3_config():

Comments (0)